From 1a03c8b43a9545821650f94431e5780f44a1f9a3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Sep 2019 23:00:40 +0800 Subject: [PATCH 1/5] Support passing all Table metadata in TableProvider --- .../spark/sql/v2/avro/AvroDataSourceV2.scala | 19 +- .../apache/spark/sql/v2/avro/AvroTable.scala | 2 +- .../sql/kafka010/KafkaSourceProvider.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../kafka010/KafkaSourceProviderSuite.scala | 2 +- .../sql/connector/catalog/TableProvider.java | 41 ++- .../connector/catalog/CatalogManager.scala | 19 +- .../spark/sql/connector/catalog/V1Table.scala | 21 +- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../spark/sql/connector/InMemoryTable.scala | 2 + .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/noop/NoopDataSource.scala | 2 +- .../v2/CatalogExtensionForTableProvider.scala | 95 ++++++ .../datasources/v2/DataSourceV2Utils.scala | 22 +- .../datasources/v2/FileDataSourceV2.scala | 3 +- .../execution/datasources/v2/FileTable.scala | 11 +- .../datasources/v2/V2SessionCatalog.scala | 48 +-- .../datasources/v2/csv/CSVDataSourceV2.scala | 19 +- .../datasources/v2/csv/CSVTable.scala | 2 +- .../v2/json/JsonDataSourceV2.scala | 19 +- .../datasources/v2/json/JsonTable.scala | 2 +- .../datasources/v2/orc/OrcDataSourceV2.scala | 20 +- .../datasources/v2/orc/OrcTable.scala | 2 +- .../v2/parquet/ParquetDataSourceV2.scala | 19 +- .../datasources/v2/parquet/ParquetTable.scala | 2 +- .../v2/text/TextDataSourceV2.scala | 20 +- .../datasources/v2/text/TextTable.scala | 2 +- .../sql/execution/streaming/console.scala | 2 +- .../sql/execution/streaming/memory.scala | 2 +- .../sources/RateStreamProvider.scala | 58 ++-- .../sources/TextSocketSourceProvider.scala | 26 +- .../sql/streaming/DataStreamReader.scala | 6 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../connector/JavaAdvancedDataSourceV2.java | 2 +- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../JavaPartitionAwareDataSource.java | 3 +- .../JavaReportStatisticsDataSource.java | 2 +- .../JavaSchemaRequiredDataSource.java | 10 +- .../sql/connector/JavaSimpleDataSourceV2.java | 4 +- ...SourceV2DataFrameSessionCatalogSuite.scala | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- .../connector/DataSourceV2SQLUsingSuite.scala | 286 ++++++++++++++++++ .../sql/connector/DataSourceV2Suite.scala | 22 +- .../FileDataSourceV2FallBackSuite.scala | 4 +- .../connector/SimpleWritableDataSource.scala | 5 +- .../connector/TableCapabilityCheckSuite.scala | 2 +- .../sql/connector/V1WriteFallbackSuite.scala | 6 +- .../sources/TextSocketStreamSuite.scala | 11 +- .../sources/StreamingDataSourceV2Suite.scala | 24 +- 50 files changed, 683 insertions(+), 213 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index c6f52d676422c..98fe9e10c1afb 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.v2.avro +import java.util + import org.apache.spark.sql.avro.AvroFileFormat import 
org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class AvroDataSourceV2 extends FileDataSourceV2 { @@ -29,15 +31,18 @@ class AvroDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "avro" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + AvroTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + AvroTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index 765e5727d944a..952d9f55d88ad 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AvroTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index c15f08d78741d..ad22cd9965755 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -108,8 +108,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParameters)) } - override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { - val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + override def loadTable(properties: java.util.Map[String, String]): KafkaTable = { + val includeHeaders = Option(properties.get(INCLUDE_HEADERS)).map(_.toBoolean).getOrElse(false) new KafkaTable(includeHeaders) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 3ee59e57a6edf..9867f60b50be7 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1142,7 +1142,7 @@ class 
KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} val dsOptions = new CaseInsensitiveStringMap(options.asJava) - val table = provider.getTable(dsOptions) + val table = provider.loadTable(dsOptions) val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) val inputPartitions = stream.planInputPartitions( KafkaSourceOffset(Map(tp -> 0L)), diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala index f7b00b31ebba0..984e100b562f0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala @@ -125,6 +125,6 @@ class KafkaSourceProviderSuite extends SparkFunSuite { private def getKafkaDataSourceScan(options: CaseInsensitiveStringMap): Scan = { val provider = new KafkaSourceProvider() - provider.getTable(options).newScanBuilder(options).build() + provider.loadTable(options).newScanBuilder(options).build() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java index e9fd87d0e2d40..c3eef90e3b32c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -17,7 +17,10 @@ package org.apache.spark.sql.connector.catalog; +import java.util.Map; + import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -36,26 +39,32 @@ public interface TableProvider { /** - * Return a {@link Table} instance to do read/write with user-specified options. + * Return a {@link Table} instance with the given table properties to do read/write. + * Implementations should infer the table schema and partitioning. * - * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param properties The properties of the table to load. It should be sufficient to define and + * access a table. The properties map may be {@link CaseInsensitiveStringMap}. */ - Table getTable(CaseInsensitiveStringMap options); + Table loadTable(Map<String, String> properties); /** - * Return a {@link Table} instance to do read/write with user-specified schema and options. - *

- * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user-specified schema. - *

- * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. - * @param schema the user-specified schema. - * @throws UnsupportedOperationException + * Return a {@link Table} instance with the given table schema, partitioning and properties to do + * read/write. The returned table must report the same schema and partitioning as the given + * ones. + * + * By default this method simply calls {@link #loadTable(Map)}. The returned table may report + * a different schema/partitioning and fail the job later. Implementations should override + * this method if they can leverage the given schema and partitioning. + * + * @param schema The schema of the table to load. + * @param partitions The data partitioning of the table to load. + * @param properties The properties of the table to load. It should be sufficient to define and + * access a table. The properties map may be {@link CaseInsensitiveStringMap}. */ - default Table getTable(CaseInsensitiveStringMap options, StructType schema) { - throw new UnsupportedOperationException( - this.getClass().getSimpleName() + " source does not support user-specified schema"); + default Table loadTable( + StructType schema, + Transform[] partitions, + Map<String, String> properties) { + return loadTable(properties); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index be14b17701276..a7f9728cf2b80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -77,16 +77,15 @@ class CatalogManager( // If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2 // session catalog. Otherwise, return the default session catalog. def v2SessionCatalog: CatalogPlugin = { - conf.getConf(SQLConf.V2_SESSION_CATALOG).map { customV2SessionCatalog => - try { - catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog()) - } catch { - case NonFatal(_) => - logError( - "Fail to instantiate the custom v2 session catalog: " + customV2SessionCatalog) - defaultSessionCatalog - } - }.getOrElse(defaultSessionCatalog) + try { + catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog()) + } catch { + case NonFatal(_) => + logError( + "Fail to instantiate the custom v2 session catalog: " + + conf.getConfString(CatalogManager.SESSION_CATALOG_NAME)) + defaultSessionCatalog + } } private def getDefaultNamespace(c: CatalogPlugin) = c match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index 616c3cf696396..cb23433b989a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.types.StructType * An implementation of catalog v2 `Table` to expose v1 table metadata.
*/ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { + assert(v1Table.provider.isDefined) + implicit class IdentifierHelper(identifier: TableIdentifier) { def quoted: String = { identifier.database match { @@ -38,7 +40,6 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { Seq(db, identifier.table).map(quote).mkString(".") case _ => quote(identifier.table) - } } @@ -51,20 +52,18 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { } } - def catalogTable: CatalogTable = v1Table - - lazy val options: Map[String, String] = { - v1Table.storage.locationUri match { + override lazy val properties: util.Map[String, String] = { + val pathOption = v1Table.storage.locationUri match { case Some(uri) => - v1Table.storage.properties + ("path" -> uri.toString) + Some("path" -> uri.toString) case _ => - v1Table.storage.properties + None } + val providerOption = "provider" -> v1Table.provider.get + (v1Table.storage.properties ++ v1Table.properties ++ pathOption + providerOption).asJava } - override lazy val properties: util.Map[String, String] = v1Table.properties.asJava - - override lazy val schema: StructType = v1Table.schema + override def schema: StructType = v1Table.schema override lazy val partitioning: Array[Transform] = { val partitions = new mutable.ArrayBuffer[Transform]() @@ -84,5 +83,5 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]() - override def toString: String = s"UnresolvedTable($name)" + override def toString: String = s"V1Table($name)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3d28b5e93a17e..bfc7871d6bcdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1982,7 +1982,8 @@ object SQLConf { "passed the Spark built-in session catalog, so that it may delegate calls to the " + "built-in session catalog.") .stringConf - .createOptional + .createWithDefault( + "org.apache.spark.sql.execution.datasources.v2.CatalogExtensionForTableProvider") val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast") .doc("When true, the upcast will be loose and allows string to atomic types.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 414f9d5834868..59299b3157982 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -72,6 +72,8 @@ class InMemoryTable( this } + def clear(): Unit = dataMap.synchronized(dataMap.clear()) + override def capabilities: util.Set[TableCapability] = Set( TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b9cc25817d2f3..a5c75eff84c55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -216,8 +216,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ 
pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) - case _ => provider.getTable(dsOptions) + case Some(schema) => + DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, source, schema, dsOptions) + case _ => + provider.loadTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3f7016df2eb42..fa86af72a262a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -260,7 +260,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { + provider.loadTable(dsOptions) match { case table: SupportsWrite if table.supports(BATCH_WRITE) => if (partitioningColumns.nonEmpty) { throw new AnalysisException("Cannot write data to TableProvider implementation " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 3f4f29c3e135a..0cba66efb3783 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable + override def loadTable(properties: util.Map[String, String]): Table = NoopTable } private[noop] object NoopTable extends Table with SupportsWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala new file mode 100644 index 0000000000000..4898575691480 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util + +import scala.util.control.NonFatal + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType + +class CatalogExtensionForTableProvider extends DelegatingCatalogExtension { + + private val conf = SQLConf.get + + override def loadTable(ident: Identifier): Table = { + val table = super.loadTable(ident) + tryResolveTableProvider(table) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val provider = properties.getOrDefault("provider", conf.defaultDataSourceName) + val maybeProvider = DataSource.lookupDataSourceV2(provider, conf) + val (actualSchema, actualPartitioning) = if (maybeProvider.isDefined && schema.isEmpty) { + // A sanity check. The parser should guarantee it. + assert(partitions.isEmpty) + // If `CREATE TABLE ... USING` does not specify table metadata, get the table metadata from + // data source first. + val table = maybeProvider.get.loadTable(properties) + table.schema() -> table.partitioning() + } else { + schema -> partitions + } + super.createTable(ident, actualSchema, actualPartitioning, properties) + // call `loadTable` to make sure the schema/partitioning specified in `CREATE TABLE ... USING` + // matches the actual data schema/partitioning. If error happens during table loading, drop + // the table. + try { + loadTable(ident) + } catch { + case NonFatal(e) => + dropTable(ident) + throw e + } + } + + private def tryResolveTableProvider(table: Table): Table = { + val providerName = table.properties().get("provider") + assert(providerName != null) + DataSource.lookupDataSourceV2(providerName, conf).map { provider => + // TODO: support file source v2 in CREATE TABLE USING. 
+ if (provider.isInstanceOf[FileDataSourceV2]) { + table + } else { + val loaded = provider.loadTable(table.schema, table.partitioning, table.properties) + if (loaded.schema().asNullable != table.schema.asNullable) { + throw new AnalysisException(s"Table provider '$providerName' returns a table " + + "which has inappropriate schema:\n" + + s"schema in Spark meta-store:\t${table.schema}\n" + + s"schema from table provider:\t${loaded.schema()}") + } + if (!loaded.partitioning().sameElements(table.partitioning)) { + throw new AnalysisException(s"Table provider '$providerName' returns a table " + + "which has inappropriate partitioning:\n" + + s"partitioning in Spark meta-store:\t${table.partitioning.mkString(", ")}\n" + + s"partitioning from table provider:\t${loaded.partitioning.mkString(", ")}") + } + loaded + } + }.getOrElse(table) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 52294ae2cb851..13d921dff0e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -20,8 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern import org.apache.spark.internal.Logging -import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, TableProvider} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] object DataSourceV2Utils extends Logging { @@ -57,4 +60,21 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } } + + def loadTableWithUserSpecifiedSchema( + provider: TableProvider, + providerName: String, + schema: StructType, + options: CaseInsensitiveStringMap): Table = { + // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified + // partitioning. 
+ val table = provider.loadTable(schema, Array.empty, options) + if (table.schema().asNullable != schema.asNullable) { + throw new AnalysisException(s"Table provider '$providerName' returns a table which " + + "has inappropriate schema:\n" + + s"user-specified schema:\t$schema\n" + + s"schema from table provider:\t${table.schema()}") + } + table + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index e0091293d1669..8eacd6df0c371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.TableProvider import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils /** @@ -41,7 +40,7 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { lazy val sparkSession = SparkSession.active - protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = { + protected def getPaths(map: java.util.Map[String, String]): Seq[String] = { val objectMapper = new ObjectMapper() val paths = Option(map.get("paths")).map { pathStr => objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 5329e09916bd6..30f3416892af4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -34,15 +34,20 @@ import org.apache.spark.sql.util.SchemaUtils abstract class FileTable( sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType]) extends Table with SupportsRead with SupportsWrite { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + private def caseSensitiveOptions = options match { + case m: CaseInsensitiveStringMap => m.asCaseSensitiveMap() + case other => other + } + lazy val fileIndex: PartitioningAwareFileIndex = { - val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val caseSensitiveMap = caseSensitiveOptions.asScala.toMap // Hadoop Configurations are case sensitive. 
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) if (FileStreamSink.hasMetadata(paths, hadoopConf, sparkSession.sessionState.conf)) { @@ -104,7 +109,7 @@ abstract class FileTable( override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms - override def properties: util.Map[String, String] = options.asCaseSensitiveMap + override def properties: util.Map[String, String] = caseSensitiveOptions override def capabilities: java.util.Set[TableCapability] = FileTable.CAPABILITIES diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index dffb9cb67b5c2..e71aea93f263f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -42,7 +42,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import V2SessionCatalog._ - override val defaultNamespace: Array[String] = Array("default") + override def defaultNamespace: Array[String] = Array("default") override def name: String = CatalogManager.SESSION_CATALOG_NAME @@ -61,6 +61,14 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) } } + override def tableExists(ident: Identifier): Boolean = { + if (ident.namespace().length <= 1) { + catalog.tableExists(ident.asTableIdentifier) + } else { + false + } + } + override def loadTable(ident: Identifier): Table = { val catalogTable = try { catalog.getTableMetadata(ident.asTableIdentifier) @@ -69,6 +77,10 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) throw new NoSuchTableException(ident) } + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw new NoSuchTableException(ident) + } + V1Table(catalogTable) } @@ -109,7 +121,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) throw new TableAlreadyExistsException(ident) } - loadTable(ident) + V1Table(tableDesc) } override def alterTable( @@ -124,31 +136,27 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) val properties = CatalogV2Util.applyPropertiesChanges(catalogTable.properties, changes) val schema = CatalogV2Util.applySchemaChanges(catalogTable.schema, changes) + val updatedTable = catalogTable.copy(properties = properties, schema = schema) try { - catalog.alterTable(catalogTable.copy(properties = properties, schema = schema)) + catalog.alterTable(updatedTable) } catch { case _: NoSuchTableException => throw new NoSuchTableException(ident) } - loadTable(ident) + V1Table(updatedTable) } override def dropTable(ident: Identifier): Boolean = { - try { - if (loadTable(ident) != null) { - catalog.dropTable( - ident.asTableIdentifier, - ignoreIfNotExists = true, - purge = true /* skip HDFS trash */) - true - } else { - false - } - } catch { - case _: NoSuchTableException => - false + if (tableExists(ident)) { + catalog.dropTable( + ident.asTableIdentifier, + ignoreIfNotExists = true, + purge = true /* skip HDFS trash */) + true + } else { + false } } @@ -156,9 +164,9 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) if (tableExists(newIdent)) { throw new TableAlreadyExistsException(newIdent) } - - // Load table to make sure the table exists - loadTable(oldIdent) + if (!tableExists(oldIdent)) { + throw new NoSuchTableException(oldIdent) + } 
catalog.renameTable(oldIdent.asTableIdentifier, newIdent.asTableIdentifier) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index 1f99d4282f6da..bd6ae5b94638b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2.csv +import java.util + import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class CSVDataSourceV2 extends FileDataSourceV2 { @@ -29,15 +31,18 @@ class CSVDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "csv" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + CSVTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + CSVTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala index 04beee0e3b0f2..a067b6a370419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class CSVTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala index 7a0949e586cd8..d8ba1e16e19cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2.json +import java.util + import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import 
org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class JsonDataSourceV2 extends FileDataSourceV2 { @@ -29,16 +31,19 @@ class JsonDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "json" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + JsonTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + JsonTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala index 9bb615528fc5d..518a9242464ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class JsonTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 8665af33b976a..cd374cdfae96f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2.orc +import java.util + import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class OrcDataSourceV2 extends FileDataSourceV2 { @@ -29,16 +31,18 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + 
OrcTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + OrcTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index f2e4b88e9f1ae..ebbfa4feaa376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala index 8cb6186c12ff3..5bac31f5ce2dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2.parquet +import java.util + import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class ParquetDataSourceV2 extends FileDataSourceV2 { @@ -29,16 +31,19 @@ class ParquetDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + ParquetTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + ParquetTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index 2ad64b1aa5244..5b8c3ac9d77bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala index 049c717effa26..d5545e5f93181 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala @@ -16,12 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2.text +import java.util + import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextDataSourceV2 extends FileDataSourceV2 { @@ -29,16 +31,18 @@ class TextDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "text" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def loadTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + TextTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + TextTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala index 87bfa84985e5a..c21c1441afff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class TextTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 20eb7ae5a6d96..e5ac910c32804 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -39,7 +39,7 @@ class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { ConsoleTable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 911a526428cf4..7dc487a873415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -95,7 +95,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. object MemoryStreamTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3f7b0377f1eab..f0a4561d278cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -46,9 +46,25 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends TableProvider with DataSourceRegister { - import RateStreamProvider._ + override def loadTable(properties: util.Map[String, String]): Table = { + RateStreamTable + } + + override def shortName(): String = "rate" +} + +object RateStreamTable extends Table with SupportsRead { + + override def name(): String = s"RateStreamTable" + + override def schema(): StructType = RateStreamProvider.SCHEMA - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + import RateStreamProvider._ val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( @@ -69,38 +85,16 @@ class RateStreamProvider extends TableProvider with DataSourceRegister { throw new IllegalArgumentException( s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") } - new RateStreamTable(rowsPerSecond, rampUpTimeSeconds, numPartitions) - } + () => new Scan { + override def readSchema(): StructType = RateStreamProvider.SCHEMA - override def shortName(): String = "rate" -} - -class RateStreamTable( - rowsPerSecond: Long, - rampUpTimeSeconds: Long, - numPartitions: Int) - extends Table with SupportsRead { - - override def name(): String = { - s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=$numPartitions)" - } + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = + new RateStreamMicroBatchStream( + rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) - override def schema(): StructType = RateStreamProvider.SCHEMA - - override def capabilities(): util.Set[TableCapability] = { - Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava - } - - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan { - override def readSchema(): StructType = RateStreamProvider.SCHEMA - - override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = - new RateStreamMicroBatchStream( - rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) - - override def toContinuousStream(checkpointLocation: String): ContinuousStream = - new RateStreamContinuousStream(rowsPerSecond, numPartitions) + override def toContinuousStream(checkpointLocation: String): ContinuousStream = + new RateStreamContinuousStream(rowsPerSecond, numPartitions) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index fae3cb765c0c9..30bd766e92adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream @@ -36,7 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: CaseInsensitiveStringMap): Unit = { + private def checkParameters(params: util.Map[String, String]): Unit = { logWarning("The socket source should not be used for production applications! 
" + "It does not support recovery.") if (!params.containsKey("host")) { @@ -46,7 +47,7 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - params.getBoolean("includeTimestamp", false) + Option(params.get("includeTimestamp")).foreach(_.toBoolean) } match { case Success(_) => case Failure(_) => @@ -54,13 +55,22 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } } - override def getTable(options: CaseInsensitiveStringMap): Table = { - checkParameters(options) + override def loadTable(properties: util.Map[String, String]): Table = { + checkParameters(properties) new TextSocketTable( - options.get("host"), - options.getInt("port", -1), - options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), - options.getBoolean("includeTimestamp", false)) + properties.get("host"), + properties.get("port").toInt, + Option(properties.get("numPartitions")).map(_.toInt) + .getOrElse(SparkSession.active.sparkContext.defaultParallelism), + Option(properties.get("includeTimestamp")).map(_.toBoolean).getOrElse(false)) + } + + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "TextSocketSourceProvider source does not support user-specified schema") } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4a6516d325ddd..ac0ce337a47b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -188,8 +188,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) - case _ => provider.getTable(dsOptions) + case Some(schema) => + DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, source, schema, dsOptions) + case _ => + provider.loadTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 74170b1b5d77e..2f8fe068699f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -308,7 +308,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { + provider.loadTable(dsOptions) match { case table: SupportsWrite if table.supports(STREAMING_WRITE) => table case _ => createV1Sink() diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 
9386ab51d64f0..ccbb67a420c40 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements TableProvider { @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(Map properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 76da45e182b3c..b767512985628 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql.connector; import java.io.IOException; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Table; @@ -32,7 +33,6 @@ import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; - public class JavaColumnarDataSourceV2 implements TableProvider { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -52,7 +52,7 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(Map properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index fbbc457b2945d..d1ee9c181f052 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -56,7 +57,7 @@ public Partitioning outputPartitioning() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(Map properties) { return new JavaSimpleBatchTable() { @Override public Transform[] partitioning() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index 49438fe668d56..6e294fd44bcfe 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -54,7 +54,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(java.util.Map options) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index 2181887ae54e2..01b94764fef23 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -17,8 +17,11 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; + import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.types.StructType; @@ -46,7 +49,10 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options, StructType schema) { + public Table loadTable( + StructType schema, + Transform[] partitions, + Map properties) { return new JavaSimpleBatchTable() { @Override @@ -62,7 +68,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(Map options) { throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 8b6d71b986ff7..498732bc4bc0d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -17,6 +17,8 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; + import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; @@ -37,7 +39,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table loadTable(Map properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 207ece83759ed..4503b2228c68f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -92,7 +92,7 @@ class DataSourceV2DataFrameSessionCatalogSuite } class InMemoryTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException("D'oh!") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d353e6b3f56d8..abe2d9c6d07cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1162,7 +1162,7 @@ class DataSourceV2SQLSuite /** 
Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: java.util.Map[String, String]): Table = { throw new UnsupportedOperationException("Unnecessary for DDL tests") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala new file mode 100644 index 0000000000000..b2a6fd6f3ddd7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType + +class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { + + override def beforeEach(): Unit = { + super.beforeEach() + ReadWriteV2Source.table.clear() + PartitionedV2Source.table.clear() + } + + test("basic") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + + val e = intercept[AnalysisException](sql("INSERT INTO t SELECT 1")) + assert(e.message.contains("not enough data columns")) + + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("CREATE TABLE with a mismatched schema") { + withTable("t") { + val e = intercept[AnalysisException]( + sql(s"CREATE TABLE t(i INT) USING ${classOf[ReadWriteV2Source].getName}") + ) + assert(e.getMessage.contains("returns a table which has inappropriate schema")) + + sql(s"CREATE TABLE t(i INT, j INT) USING ${classOf[ReadWriteV2Source].getName}") + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("CREATE TABLE with a mismatched partitioning") { + withTable("t") { + val e = intercept[AnalysisException]( + sql( + s""" + |CREATE TABLE t(i INT, j INT) USING ${classOf[PartitionedV2Source].getName} + |PARTITIONED BY (i) + """.stripMargin) + ) + assert(e.getMessage.contains("returns a table which has inappropriate partitioning")) + + sql( + s""" + |CREATE TABLE t(i INT, j INT) USING ${classOf[PartitionedV2Source].getName} + |PARTITIONED BY (j) + """.stripMargin) + sql("INSERT INTO t SELECT 1, -1") + 
checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("read-only table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadOnlyV2Source].getName}") + checkAnswer(spark.table("t"), Seq(Row(1), Row(-1))) + + val e1 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1, -1")) + assert(e1.message.contains("too many data columns")) + + val e2 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1")) + assert(e2.message.contains("does not support append in batch mode")) + } + } + + test("write-only table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[WriteOnlyV2Source].getName}") + + val e1 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1, 1")) + assert(e1.message.contains("too many data columns")) + + sql("INSERT INTO t SELECT 1") + + val e2 = intercept[AnalysisException](sql("SELECT * FROM t").collect()) + assert(e2.message.contains("Table write-only does not support batch scan")) + } + } + + test("CREATE TABLE AS SELECT") { + withTable("t") { + sql( + s""" + |CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName} + |AS SELECT 1 AS i, -1 AS j + """.stripMargin) + checkAnswer(spark.table("t"), Row(1, -1)) + + sql("INSERT INTO t SELECT 2, -2") + checkAnswer(spark.table("t"), Seq(Row(1, -1), Row(2, -2))) + } + } + + test("INSERT OVERWRITE with non-partitioned table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + + sql("INSERT OVERWRITE t SELECT 2, -2") + checkAnswer(spark.table("t"), Row(2, -2)) + } + } + + test("INSERT OVERWRITE with partitioned table (static mode)") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "static") { + sql("INSERT OVERWRITE t PARTITION(j=2) SELECT 4") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(4, 2))) + + sql("INSERT OVERWRITE t SELECT 0, 1") + checkAnswer(spark.table("t"), Row(0, 1)) + } + } + } + + test("INSERT OVERWRITE with partitioned table (dynamic mode)") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { + sql("INSERT OVERWRITE t PARTITION(j=2) SELECT 4") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(4, 2))) + + sql("INSERT OVERWRITE t SELECT 0, 1") + checkAnswer(spark.table("t"), Seq(Row(0, 1), Row(4, 2))) + } + } + } + + // TODO: enable it when DELETE FROM supports v2 session catalog. 
+ ignore("DELETE FROM") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + sql("DELETE FROM t WHERE j = 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1))) + } + } + + test("ALTER TABLE") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + + sql("ALTER TABLE t DROP COLUMN i") + val e = intercept[AnalysisException](sql("SELECT * FROM t")) + assert(e.message.contains("returns a table which has inappropriate schema")) + } + } + + test("ALTER TABLE with table supporting schema changing") { + withTable("t") { + sql(s"CREATE TABLE t(i INT) USING ${classOf[DynamicSchemaV2Source].getName}") + checkAnswer(spark.table("t"), Nil) + + sql("ALTER TABLE t ADD COLUMN j INT") + checkAnswer(spark.table("t"), Nil) + } + } +} + +object ReadWriteV2Source { + val table = { + val schema = new StructType().add("i", "int").add("j", "int") + val partitioning = Array.empty[Transform] + val properties = util.Collections.emptyMap[String, String] + new InMemoryTable("read-write", schema, partitioning, properties) {} + } +} + +class ReadWriteV2Source extends TableProvider { + // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data + // of in-memory table, we keep the table instance in an object. + override def loadTable(properties: util.Map[String, String]): Table = { + ReadWriteV2Source.table + } +} + +object PartitionedV2Source { + val table = { + val schema = new StructType().add("i", "int").add("j", "int") + val partitioning = Array[Transform](LogicalExpressions.identity("j")) + val properties = util.Collections.emptyMap[String, String] + new InMemoryTable("read-write", schema, partitioning, properties) {} + } +} + +class PartitionedV2Source extends TableProvider { + // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data + // of in-memory table, we keep the table instance in an object. 
+ override def loadTable(properties: util.Map[String, String]): Table = { + PartitionedV2Source.table + } +} + +class ReadOnlyV2Source extends TableProvider { + override def loadTable(properties: util.Map[String, String]): Table = { + val schema = new StructType().add("i", "int") + val partitions = Array.empty[Transform] + val properties = util.Collections.emptyMap[String, String] + val table = new InMemoryTable("read-only", schema, partitions, properties) { + override def capabilities: util.Set[TableCapability] = { + Set(TableCapability.BATCH_READ).asJava + } + } + val rows = new BufferedRows() + rows.withRow(InternalRow(1)).withRow(InternalRow(-1)) + table.withData(Array(rows)) + } +} + +class WriteOnlyV2Source extends TableProvider { + override def loadTable(properties: util.Map[String, String]): Table = { + val schema = new StructType().add("i", "int") + val partitions = Array.empty[Transform] + val properties = util.Collections.emptyMap[String, String] + new InMemoryTable("write-only", schema, partitions, properties) { + override def capabilities: util.Set[TableCapability] = { + Set(TableCapability.BATCH_WRITE).asJava + } + } + } +} + +class DynamicSchemaV2Source extends TableProvider { + override def loadTable(properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", new StructType(), Array.empty, properties) {} + } + + override def loadTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", schema, partitions, properties) {} + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 138bbc3f04f64..f7ffef7d3a49c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} @@ -444,7 +445,7 @@ class SimpleSinglePartitionSource extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -461,7 +462,7 @@ class SimpleDataSourceV2 extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -470,7 +471,7 @@ class SimpleDataSourceV2 extends TableProvider { class AdvancedDataSourceV2 extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def loadTable(options: 
util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } @@ -566,11 +567,14 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(options: util.Map[String, String]): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + override def loadTable( + schema: StructType, + partitions: Array[Transform], + options: util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema @@ -595,7 +599,7 @@ class ColumnarDataSourceV2 extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -666,7 +670,7 @@ class PartitionAwareDataSource extends TableProvider { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -706,7 +710,7 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(options: util.Map[String, String]): Table = { new MyTable(options) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") @@ -732,7 +736,7 @@ class ReportStatisticsDataSource extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(options: util.Map[String, String]): Table = { new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 2b3340527a4e2..2a27a039acc95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -40,7 +40,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: java.util.Map[String, String]): Table = { new DummyReadOnlyFileTable } } @@ -64,7 +64,7 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: java.util.Map[String, String]): Table = { new DummyWriteOnlyFileTable } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 22d3750022c57..9f1c2f72775ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -131,8 +131,7 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } } - class MyTable(options: CaseInsensitiveStringMap) - extends SimpleBatchTable with SupportsWrite { + class MyTable(options: util.Map[String, String]) extends SimpleBatchTable with SupportsWrite { private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration @@ -151,7 +150,7 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(options: util.Map[String, String]): Table = { new MyTable(options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index ce6d56cf84df1..7f7755d3151d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -204,7 +204,7 @@ private case object TestRelation extends LeafNode with NamedRelation { } private object TestTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index de843ba4375d0..e0ee350f8423e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -176,14 +176,14 @@ class InMemoryV1Provider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - InMemoryV1Provider.tables.getOrElse(options.get("name"), { + override def loadTable(properties: util.Map[String, String]): Table = { + InMemoryV1Provider.tables.getOrElse(properties.get("name"), { new InMemoryTableWithV1Fallback( "InMemoryTableWithV1Fallback", new StructType(), Array.empty, - options.asCaseSensitiveMap() + properties ) }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index f791ab66e86fa..d300d47125acc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -175,13 +175,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.getTable(CaseInsensitiveStringMap.empty()) + provider.loadTable(CaseInsensitiveStringMap.empty()) } intercept[AnalysisException] { - 
provider.getTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) + provider.loadTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.getTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) + provider.loadTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) } } @@ -189,7 +189,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { val provider = new TextSocketSourceProvider val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { - provider.getTable(new CaseInsensitiveStringMap(params.asJava)) + provider.loadTable(new CaseInsensitiveStringMap(params.asJava)) } } @@ -198,9 +198,8 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { val userSpecifiedSchema = StructType( StructField("name", StringType) :: StructField("area", StringType) :: Nil) - val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) + provider.loadTable(userSpecifiedSchema, Array.empty, CaseInsensitiveStringMap.empty()) } assert(exception.getMessage.contains( "TextSocketSourceProvider source does not support user-specified schema")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index e9d148c38e6cb..6e956ad20064f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -86,6 +86,7 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { Set(STREAMING_WRITE).asJava } override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + LastWriteOptions.options = options new FakeWriteBuilder } } @@ -98,8 +99,7 @@ class FakeReadMicroBatchOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def loadTable(properties: util.Map[String, String]): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -107,6 +107,7 @@ class FakeReadMicroBatchOnly Set(MICRO_BATCH_READ).asJava } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + LastReadOptions.options = options new FakeScanBuilder } } @@ -121,8 +122,7 @@ class FakeReadContinuousOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def loadTable(properties: util.Map[String, String]): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -130,6 +130,7 @@ class FakeReadContinuousOnly Set(CONTINUOUS_READ).asJava } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + LastReadOptions.options = options new FakeScanBuilder } } @@ -139,7 +140,7 @@ class FakeReadContinuousOnly class FakeReadBothModes extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def getTable(options: 
CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -156,7 +157,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { class FakeReadNeitherMode extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -173,8 +174,7 @@ class FakeWriteOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastWriteOptions.options = options + override def loadTable(properties: util.Map[String, String]): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -184,7 +184,7 @@ class FakeWriteOnly class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -212,7 +212,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def loadTable(properties: util.Map[String, String]): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -377,10 +377,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + .newInstance().asInstanceOf[TableProvider].loadTable(CaseInsensitiveStringMap.empty()) val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) + .newInstance().asInstanceOf[TableProvider].loadTable(CaseInsensitiveStringMap.empty()) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ trigger match { From 6731b528360ac05908237a75ac2676c2b46d3a5e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 25 Sep 2019 22:12:12 +0800 Subject: [PATCH 2/5] make the change smaller --- .../spark/sql/v2/avro/AvroDataSourceV2.scala | 9 +-- .../sql/kafka010/KafkaSourceProvider.scala | 4 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 2 +- .../kafka010/KafkaSourceProviderSuite.scala | 2 +- .../SupportsSpecifiedSchemaPartitioning.java | 49 ++++++++++++++++ .../sql/connector/catalog/TableProvider.java | 34 ++--------- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/noop/NoopDataSource.scala | 2 +- .../v2/CatalogExtensionForTableProvider.scala | 31 +++++----- .../datasources/v2/DataSourceV2Utils.scala | 22 +++---- 
.../datasources/v2/FileDataSourceV2.scala | 6 +- .../datasources/v2/csv/CSVDataSourceV2.scala | 9 +-- .../v2/json/JsonDataSourceV2.scala | 9 +-- .../datasources/v2/orc/OrcDataSourceV2.scala | 9 +-- .../v2/parquet/ParquetDataSourceV2.scala | 9 +-- .../v2/text/TextDataSourceV2.scala | 9 +-- .../sql/execution/streaming/console.scala | 2 +- .../sql/execution/streaming/memory.scala | 2 +- .../sources/RateStreamProvider.scala | 58 ++++++++++--------- .../sources/TextSocketSourceProvider.scala | 26 +++------ .../sql/streaming/DataStreamReader.scala | 4 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../connector/JavaAdvancedDataSourceV2.java | 2 +- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../JavaPartitionAwareDataSource.java | 3 +- .../JavaReportStatisticsDataSource.java | 2 +- .../JavaSchemaRequiredDataSource.java | 8 ++- .../sql/connector/JavaSimpleDataSourceV2.java | 4 +- ...SourceV2DataFrameSessionCatalogSuite.scala | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- .../connector/DataSourceV2SQLUsingSuite.scala | 19 +++--- .../sql/connector/DataSourceV2Suite.scala | 22 +++---- .../FileDataSourceV2FallBackSuite.scala | 21 ++++++- .../connector/SimpleWritableDataSource.scala | 5 +- .../connector/TableCapabilityCheckSuite.scala | 2 +- .../sql/connector/V1WriteFallbackSuite.scala | 6 +- .../sources/TextSocketStreamSuite.scala | 20 ++----- .../sources/StreamingDataSourceV2Suite.scala | 24 ++++---- 39 files changed, 245 insertions(+), 207 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index 98fe9e10c1afb..e45f750be329b 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class AvroDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class AvroDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "avro" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index ad22cd9965755..c15f08d78741d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -108,8 +108,8 @@ private[kafka010] class 
KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParameters)) } - override def loadTable(properties: java.util.Map[String, String]): KafkaTable = { - val includeHeaders = Option(properties.get(INCLUDE_HEADERS)).map(_.toBoolean).getOrElse(false) + override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { + val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) new KafkaTable(includeHeaders) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 9867f60b50be7..3ee59e57a6edf 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1142,7 +1142,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} val dsOptions = new CaseInsensitiveStringMap(options.asJava) - val table = provider.loadTable(dsOptions) + val table = provider.getTable(dsOptions) val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) val inputPartitions = stream.planInputPartitions( KafkaSourceOffset(Map(tp -> 0L)), diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala index 984e100b562f0..f7b00b31ebba0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceProviderSuite.scala @@ -125,6 +125,6 @@ class KafkaSourceProviderSuite extends SparkFunSuite { private def getKafkaDataSourceScan(options: CaseInsensitiveStringMap): Scan = { val provider = new KafkaSourceProvider() - provider.loadTable(options).newScanBuilder(options).build() + provider.getTable(options).newScanBuilder(options).build() } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java new file mode 100644 index 0000000000000..308ef8a52c5e4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog; + +import java.util.Map; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to + * return a table instance with specified schema and partitioning, so that they may avoid expensive + * schema/partitioning inference. + */ +@Evolving +public interface SupportsSpecifiedSchemaPartitioning extends TableProvider { + + /** + * Return a {@link Table} instance with the given table schema, partitioning and properties to do + * read/write . The returned table must report the same schema and partitioning with the given + * ones. + * + * @param schema The schema of the table to load. + * @param partitions The data partitioning of the table to load. + * @param properties The properties of the table to load. It should be sufficient to define and + * access a table. The properties map may be {@link CaseInsensitiveStringMap}. + */ + Table getTable( + StructType schema, + Transform[] partitions, + Map properties); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java index c3eef90e3b32c..2cf76761805d9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -17,11 +17,7 @@ package org.apache.spark.sql.connector.catalog; -import java.util.Map; - import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** @@ -39,32 +35,12 @@ public interface TableProvider { /** - * Return a {@link Table} instance with the given table properties to do read/write. + * Return a {@link Table} instance with the given table options to do read/write. * Implementations should infer the table schema and partitioning. * - * @param properties The properties of the table to load. It should be sufficient to define and - * access a table. The properties map may be {@link CaseInsensitiveStringMap}. - */ - Table loadTable(Map properties); - - /** - * Return a {@link Table} instance with the given table schema, partitioning and properties to do - * read/write . The returned table must report the same schema and partitioning with the given - * ones. - * - * By default this method simply calls {@link #loadTable(Map)}. The returned table may report - * different schema/partitioning and fail the job later. Implementation should override - * this method if they can leverage the given schema and partitioning. - * - * @param schema The schema of the table to load. - * @param partitions The data partitioning of the table to load. - * @param properties The properties of the table to load. It should be sufficient to define and - * access a table. The properties map may be {@link CaseInsensitiveStringMap}. + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. 
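(A minimal sketch of a connector opting into the mix-in above, assuming the hypothetical classes KeyValueSource and KeyValueTable and a stand-in inferSchema helper.)

import java.util

import org.apache.spark.sql.connector.catalog.{SupportsSpecifiedSchemaPartitioning, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class KeyValueSource extends TableProvider with SupportsSpecifiedSchemaPartitioning {

  // Called when Spark has no metadata for the table: the source must infer the schema itself.
  override def getTable(options: CaseInsensitiveStringMap): Table =
    new KeyValueTable(inferSchema(options), Array.empty[Transform])

  // Called when Spark already knows the schema and partitioning (e.g. from the session
  // catalog), so the potentially expensive inference step can be skipped.
  override def getTable(
      schema: StructType,
      partitions: Array[Transform],
      properties: util.Map[String, String]): Table =
    new KeyValueTable(schema, partitions)

  // Stand-in for real schema inference.
  private def inferSchema(options: CaseInsensitiveStringMap): StructType =
    new StructType().add("key", "string").add("value", "string")
}

class KeyValueTable(tableSchema: StructType, parts: Array[Transform]) extends Table {
  override def name(): String = "key-value"
  override def schema(): StructType = tableSchema
  override def partitioning(): Array[Transform] = parts
  override def capabilities(): util.Set[TableCapability] = util.Collections.emptySet[TableCapability]()
}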
*/ - default Table loadTable( - StructType schema, - Transform[] partitions, - Map properties) { - return loadTable(properties); - } + // TODO: this should take a Map as table properties. + Table getTable(CaseInsensitiveStringMap options); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a5c75eff84c55..10638df9ca52b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -217,9 +217,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = userSpecifiedSchema match { case Some(schema) => - DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, source, schema, dsOptions) + DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, schema, dsOptions) case _ => - provider.loadTable(dsOptions) + provider.getTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index fa86af72a262a..3f7016df2eb42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -260,7 +260,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.loadTable(dsOptions) match { + provider.getTable(dsOptions) match { case table: SupportsWrite if table.supports(BATCH_WRITE) => if (partitioningColumns.nonEmpty) { throw new AnalysisException("Cannot write data to TableProvider implementation " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 0cba66efb3783..3f4f29c3e135a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def loadTable(properties: util.Map[String, String]): Table = NoopTable + override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } private[noop] object NoopTable extends Table with SupportsWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala index 4898575691480..7f38d3ff34d7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala @@ -22,11 +22,12 @@ import java.util import scala.util.control.NonFatal import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table} +import 
org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, SupportsSpecifiedSchemaPartitioning, Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class CatalogExtensionForTableProvider extends DelegatingCatalogExtension { @@ -49,7 +50,7 @@ class CatalogExtensionForTableProvider extends DelegatingCatalogExtension { assert(partitions.isEmpty) // If `CREATE TABLE ... USING` does not specify table metadata, get the table metadata from // data source first. - val table = maybeProvider.get.loadTable(properties) + val table = maybeProvider.get.getTable(new CaseInsensitiveStringMap(properties)) table.schema() -> table.partitioning() } else { schema -> partitions @@ -70,26 +71,28 @@ class CatalogExtensionForTableProvider extends DelegatingCatalogExtension { private def tryResolveTableProvider(table: Table): Table = { val providerName = table.properties().get("provider") assert(providerName != null) - DataSource.lookupDataSourceV2(providerName, conf).map { provider => + DataSource.lookupDataSourceV2(providerName, conf).map { // TODO: support file source v2 in CREATE TABLE USING. - if (provider.isInstanceOf[FileDataSourceV2]) { - table - } else { - val loaded = provider.loadTable(table.schema, table.partitioning, table.properties) - if (loaded.schema().asNullable != table.schema.asNullable) { + case _: FileDataSourceV2 => table + + case s: SupportsSpecifiedSchemaPartitioning => + s.getTable(table.schema, table.partitioning, table.properties) + + case provider => + val actualTable = provider.getTable(new CaseInsensitiveStringMap(table.properties)) + if (actualTable.schema() != table.schema) { throw new AnalysisException(s"Table provider '$providerName' returns a table " + "which has inappropriate schema:\n" + s"schema in Spark meta-store:\t${table.schema}\n" + - s"schema from table provider:\t${loaded.schema()}") + s"schema from table provider:\t${actualTable.schema}") } - if (!loaded.partitioning().sameElements(table.partitioning)) { + if (!actualTable.partitioning.sameElements(table.partitioning)) { throw new AnalysisException(s"Table provider '$providerName' returns a table " + "which has inappropriate partitioning:\n" + - s"partitioning in Spark meta-store:\t${table.partitioning.mkString(", ")}\n" + - s"partitioning from table provider:\t${loaded.partitioning.mkString(", ")}") + s"partitioning in Spark meta-store: ${table.partitioning.mkString(", ")}\n" + + s"partitioning from table provider: ${actualTable.partitioning.mkString(", ")}") } - loaded - } + actualTable }.getOrElse(table) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 13d921dff0e8b..4b278ce2ea719 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,7 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider} +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsSpecifiedSchemaPartitioning, 
Table, TableProvider} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -63,18 +63,18 @@ private[sql] object DataSourceV2Utils extends Logging { def loadTableWithUserSpecifiedSchema( provider: TableProvider, - providerName: String, schema: StructType, options: CaseInsensitiveStringMap): Table = { - // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified - // partitioning. - val table = provider.loadTable(schema, Array.empty, options) - if (table.schema().asNullable != schema.asNullable) { - throw new AnalysisException(s"Table provider '$providerName' returns a table which " + - "has inappropriate schema:\n" + - s"user-specified schema:\t$schema\n" + - s"schema from table provider:\t${table.schema()}") + provider match { + case s: SupportsSpecifiedSchemaPartitioning => + // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified + // partitioning. + s.getTable(schema, Array.empty, options) + + case _ => + throw new UnsupportedOperationException( + provider.getClass.getSimpleName + " source does not support user-specified schema"); + } - table } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index 8eacd6df0c371..659f1af8452a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -20,7 +20,7 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableProvider +import org.apache.spark.sql.connector.catalog.{SupportsSpecifiedSchemaPartitioning, TableProvider} import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.util.Utils @@ -28,7 +28,9 @@ import org.apache.spark.util.Utils /** * A base interface for data source v2 implementations of the built-in file-based data sources. */ -trait FileDataSourceV2 extends TableProvider with DataSourceRegister { +trait FileDataSourceV2 extends TableProvider + with SupportsSpecifiedSchemaPartitioning with DataSourceRegister { + /** * Returns a V1 [[FileFormat]] class of the same file data source. 
* This is a solution for the following cases: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index bd6ae5b94638b..f826869a2d982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class CSVDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class CSVDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "csv" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala index d8ba1e16e19cd..d5d35f77a6ebb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class JsonDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class JsonDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "json" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index cd374cdfae96f..5e2ffdc43ca8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources._ import 
org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class OrcDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala index 5bac31f5ce2dd..b3bc5eb575997 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class ParquetDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class ParquetDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala index d5545e5f93181..fe15e55351686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextDataSourceV2 extends FileDataSourceV2 { @@ -31,13 +32,13 @@ class TextDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "text" - override def loadTable(properties: util.Map[String, String]): Table = { - val paths = getPaths(properties) + override def getTable(options: 
CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) + TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index e5ac910c32804..20eb7ae5a6d96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -39,7 +39,7 @@ class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { ConsoleTable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7dc487a873415..911a526428cf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -95,7 +95,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. object MemoryStreamTableProvider extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index f0a4561d278cd..3f7b0377f1eab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -46,25 +46,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. 
*/ class RateStreamProvider extends TableProvider with DataSourceRegister { - override def loadTable(properties: util.Map[String, String]): Table = { - RateStreamTable - } - - override def shortName(): String = "rate" -} - -object RateStreamTable extends Table with SupportsRead { - - override def name(): String = s"RateStreamTable" - - override def schema(): StructType = RateStreamProvider.SCHEMA + import RateStreamProvider._ - override def capabilities(): util.Set[TableCapability] = { - Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava - } - - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - import RateStreamProvider._ + override def getTable(options: CaseInsensitiveStringMap): Table = { val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( @@ -85,16 +69,38 @@ object RateStreamTable extends Table with SupportsRead { throw new IllegalArgumentException( s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive") } - () => new Scan { - override def readSchema(): StructType = RateStreamProvider.SCHEMA + new RateStreamTable(rowsPerSecond, rampUpTimeSeconds, numPartitions) + } - override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = - new RateStreamMicroBatchStream( - rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) + override def shortName(): String = "rate" +} - override def toContinuousStream(checkpointLocation: String): ContinuousStream = - new RateStreamContinuousStream(rowsPerSecond, numPartitions) - } +class RateStreamTable( + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int) + extends Table with SupportsRead { + + override def name(): String = { + s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=$numPartitions)" + } + + override def schema(): StructType = RateStreamProvider.SCHEMA + + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = () => new Scan { + override def readSchema(): StructType = RateStreamProvider.SCHEMA + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = + new RateStreamMicroBatchStream( + rowsPerSecond, rampUpTimeSeconds, numPartitions, options, checkpointLocation) + + override def toContinuousStream(checkpointLocation: String): ContinuousStream = + new RateStreamContinuousStream(rowsPerSecond, numPartitions) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index 30bd766e92adb..fae3cb765c0c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -27,7 +27,6 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import 
org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream @@ -37,7 +36,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: util.Map[String, String]): Unit = { + private def checkParameters(params: CaseInsensitiveStringMap): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") if (!params.containsKey("host")) { @@ -47,7 +46,7 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - Option(params.get("includeTimestamp")).foreach(_.toBoolean) + params.getBoolean("includeTimestamp", false) } match { case Success(_) => case Failure(_) => @@ -55,22 +54,13 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } } - override def loadTable(properties: util.Map[String, String]): Table = { - checkParameters(properties) + override def getTable(options: CaseInsensitiveStringMap): Table = { + checkParameters(options) new TextSocketTable( - properties.get("host"), - properties.get("port").toInt, - Option(properties.get("numPartitions")).map(_.toInt) - .getOrElse(SparkSession.active.sparkContext.defaultParallelism), - Option(properties.get("includeTimestamp")).map(_.toBoolean).getOrElse(false)) - } - - override def loadTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { - throw new UnsupportedOperationException( - "TextSocketSourceProvider source does not support user-specified schema") + options.get("host"), + options.getInt("port", -1), + options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), + options.getBoolean("includeTimestamp", false)) } /** String that represents the format that this data source provider uses. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ac0ce337a47b0..65e98b2e833fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -189,9 +189,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = userSpecifiedSchema match { case Some(schema) => - DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, source, schema, dsOptions) + DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, schema, dsOptions) case _ => - provider.loadTable(dsOptions) + provider.getTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2f8fe068699f3..74170b1b5d77e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -308,7 +308,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.loadTable(dsOptions) match { + provider.getTable(dsOptions) match { case table: SupportsWrite if table.supports(STREAMING_WRITE) => table case _ => createV1Sink() diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index ccbb67a420c40..9386ab51d64f0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 implements TableProvider { @Override - public Table loadTable(Map properties) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index b767512985628..76da45e182b3c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -18,7 +18,6 @@ package test.org.apache.spark.sql.connector; import java.io.IOException; -import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.Table; @@ -33,6 +32,7 @@ import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; + public class JavaColumnarDataSourceV2 implements TableProvider { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -52,7 +52,7 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table loadTable(Map properties) { + public Table 
getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index d1ee9c181f052..fbbc457b2945d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -57,7 +56,7 @@ public Partitioning outputPartitioning() { } @Override - public Table loadTable(Map properties) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override public Transform[] partitioning() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index 6e294fd44bcfe..49438fe668d56 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -54,7 +54,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table loadTable(java.util.Map options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index 01b94764fef23..a9ac9e77edf86 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -19,6 +19,7 @@ import java.util.Map; +import org.apache.spark.sql.connector.catalog.SupportsSpecifiedSchemaPartitioning; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.expressions.Transform; @@ -27,7 +28,8 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaSchemaRequiredDataSource implements TableProvider { +public class JavaSchemaRequiredDataSource + implements TableProvider, SupportsSpecifiedSchemaPartitioning { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -49,7 +51,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table loadTable( + public Table getTable( StructType schema, Transform[] partitions, Map properties) { @@ -68,7 +70,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { } @Override - public Table loadTable(Map options) { + public Table getTable(CaseInsensitiveStringMap options) { throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 498732bc4bc0d..8b6d71b986ff7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -17,8 +17,6 @@ package test.org.apache.spark.sql.connector; -import java.util.Map; - import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; @@ -39,7 +37,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table loadTable(Map properties) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 4503b2228c68f..207ece83759ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -92,7 +92,7 @@ class DataSourceV2DataFrameSessionCatalogSuite } class InMemoryTableProvider extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new UnsupportedOperationException("D'oh!") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index abe2d9c6d07cd..d353e6b3f56d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1162,7 +1162,7 @@ class DataSourceV2SQLSuite /** Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends TableProvider { - override def loadTable(properties: java.util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new UnsupportedOperationException("Unnecessary for DDL tests") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala index b2a6fd6f3ddd7..c40de492e137e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala @@ -23,11 +23,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsSpecifiedSchemaPartitioning, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2SQLUsingSuite 
extends QueryTest with SharedSparkSession { @@ -221,7 +222,7 @@ object ReadWriteV2Source { class ReadWriteV2Source extends TableProvider { // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data // of in-memory table, we keep the table instance in an object. - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { ReadWriteV2Source.table } } @@ -238,13 +239,13 @@ object PartitionedV2Source { class PartitionedV2Source extends TableProvider { // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data // of in-memory table, we keep the table instance in an object. - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { PartitionedV2Source.table } } class ReadOnlyV2Source extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { val schema = new StructType().add("i", "int") val partitions = Array.empty[Transform] val properties = util.Collections.emptyMap[String, String] @@ -260,7 +261,7 @@ class ReadOnlyV2Source extends TableProvider { } class WriteOnlyV2Source extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { val schema = new StructType().add("i", "int") val partitions = Array.empty[Transform] val properties = util.Collections.emptyMap[String, String] @@ -272,12 +273,12 @@ class WriteOnlyV2Source extends TableProvider { } } -class DynamicSchemaV2Source extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { - new InMemoryTable("dynamic-schema", new StructType(), Array.empty, properties) {} +class DynamicSchemaV2Source extends TableProvider with SupportsSpecifiedSchemaPartitioning { + override def getTable(options: CaseInsensitiveStringMap): Table = { + new InMemoryTable("dynamic-schema", new StructType(), Array.empty, options) {} } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index f7ffef7d3a49c..17befa47f0bd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -28,7 +28,7 @@ import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsSpecifiedSchemaPartitioning, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read._ @@ -445,7 +445,7 @@ class SimpleSinglePartitionSource extends TableProvider { } } - override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { + override def 
getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -462,7 +462,7 @@ class SimpleDataSourceV2 extends TableProvider { } } - override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -471,7 +471,7 @@ class SimpleDataSourceV2 extends TableProvider { class AdvancedDataSourceV2 extends TableProvider { - override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } @@ -559,7 +559,7 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF } -class SchemaRequiredDataSource extends TableProvider { +class SchemaRequiredDataSource extends TableProvider with SupportsSpecifiedSchemaPartitioning { class MyScanBuilder(schema: StructType) extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = Array.empty @@ -567,11 +567,11 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def loadTable(options: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def loadTable( + override def getTable( schema: StructType, partitions: Array[Transform], options: util.Map[String, String]): Table = { @@ -599,7 +599,7 @@ class ColumnarDataSourceV2 extends TableProvider { } } - override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -670,7 +670,7 @@ class PartitionAwareDataSource extends TableProvider { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def loadTable(options: util.Map[String, String]): Table = new SimpleBatchTable { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -710,7 +710,7 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def loadTable(options: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") @@ -736,7 +736,7 @@ class ReportStatisticsDataSource extends TableProvider { } } - override def loadTable(options: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 2a27a039acc95..079853a653193 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.connector +import java.util + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.WriteBuilder import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution} @@ -40,9 +43,16 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def loadTable(properties: java.util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyReadOnlyFileTable } + + override def getTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException + } } class DummyReadOnlyFileTable extends Table with SupportsRead { @@ -64,9 +74,16 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def loadTable(properties: java.util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyWriteOnlyFileTable } + + override def getTable( + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException + } } class DummyWriteOnlyFileTable extends Table with SupportsWrite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 9f1c2f72775ac..22d3750022c57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -131,7 +131,8 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } } - class MyTable(options: util.Map[String, String]) extends SimpleBatchTable with SupportsWrite { + class MyTable(options: CaseInsensitiveStringMap) + extends SimpleBatchTable with SupportsWrite { private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration @@ -150,7 +151,7 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } - override def loadTable(options: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index 7f7755d3151d1..ce6d56cf84df1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -204,7 +204,7 @@ private case object TestRelation extends LeafNode with NamedRelation { } private object TestTableProvider extends TableProvider { - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index e0ee350f8423e..de843ba4375d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -176,14 +176,14 @@ class InMemoryV1Provider extends TableProvider with DataSourceRegister with CreatableRelationProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { - override def loadTable(properties: util.Map[String, String]): Table = { - InMemoryV1Provider.tables.getOrElse(properties.get("name"), { + InMemoryV1Provider.tables.getOrElse(options.get("name"), { new InMemoryTableWithV1Fallback( "InMemoryTableWithV1Fallback", new StructType(), Array.empty, - properties + options.asCaseSensitiveMap() ) }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index d300d47125acc..f0bf6d2b06120 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -175,13 +175,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.loadTable(CaseInsensitiveStringMap.empty()) + provider.getTable(CaseInsensitiveStringMap.empty()) } intercept[AnalysisException] { - provider.loadTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.loadTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) } } @@ -189,22 +189,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { val provider = new TextSocketSourceProvider val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { - provider.loadTable(new CaseInsensitiveStringMap(params.asJava)) + provider.getTable(new CaseInsensitiveStringMap(params.asJava)) } } - test("user-specified schema given") { - val provider = new TextSocketSourceProvider - val userSpecifiedSchema = StructType( - StructField("name", StringType) :: - StructField("area", StringType) :: Nil) - val exception = intercept[UnsupportedOperationException] { - provider.loadTable(userSpecifiedSchema, Array.empty, CaseInsensitiveStringMap.empty()) - } - assert(exception.getMessage.contains( - "TextSocketSourceProvider source does not support user-specified schema")) - } - test("input row metrics") { serverThread = new ServerThread() serverThread.start() diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 6e956ad20064f..e9d148c38e6cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -86,7 +86,6 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { Set(STREAMING_WRITE).asJava } override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { - LastWriteOptions.options = options new FakeWriteBuilder } } @@ -99,7 +98,8 @@ class FakeReadMicroBatchOnly override def keyPrefix: String = shortName() - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { + LastReadOptions.options = options new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -107,7 +107,6 @@ class FakeReadMicroBatchOnly Set(MICRO_BATCH_READ).asJava } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - LastReadOptions.options = options new FakeScanBuilder } } @@ -122,7 +121,8 @@ class FakeReadContinuousOnly override def keyPrefix: String = shortName() - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { + LastReadOptions.options = options new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -130,7 +130,6 @@ class FakeReadContinuousOnly Set(CONTINUOUS_READ).asJava } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - LastReadOptions.options = options new FakeScanBuilder } } @@ -140,7 +139,7 @@ class FakeReadContinuousOnly class FakeReadBothModes extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -157,7 +156,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { class FakeReadNeitherMode extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-neither-mode" - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -174,7 +173,8 @@ class FakeWriteOnly override def keyPrefix: String = shortName() - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { + LastWriteOptions.options = options new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -184,7 +184,7 @@ class FakeWriteOnly class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table 
= { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -212,7 +212,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" - override def loadTable(properties: util.Map[String, String]): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -377,10 +377,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].loadTable(CaseInsensitiveStringMap.empty()) + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].loadTable(CaseInsensitiveStringMap.empty()) + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ trigger match { From 1124b47a4cb943336b9d431d91121beb6702c699 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 9 Oct 2019 00:03:14 +0800 Subject: [PATCH 3/5] address comments --- .../spark/sql/v2/avro/AvroDataSourceV2.scala | 6 +- .../sql/kafka010/KafkaSourceProvider.scala | 14 +++ .../SupportsSpecifiedSchemaPartitioning.java | 49 ---------- .../sql/connector/catalog/TableProvider.java | 43 +++++++- .../connector/catalog/CatalogManager.scala | 19 ++-- .../spark/sql/connector/catalog/V1Table.scala | 24 ++--- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 8 +- .../datasources/noop/NoopDataSource.scala | 14 +++ .../v2/CatalogExtensionForTableProvider.scala | 98 ------------------- .../datasources/v2/DataSourceV2Utils.scala | 42 ++++++-- .../datasources/v2/FileDataSourceV2.scala | 17 +++- .../datasources/v2/V2SessionCatalog.scala | 50 ++++++++-- .../datasources/v2/csv/CSVDataSourceV2.scala | 6 +- .../v2/json/JsonDataSourceV2.scala | 6 +- .../datasources/v2/orc/OrcDataSourceV2.scala | 6 +- .../v2/parquet/ParquetDataSourceV2.scala | 6 +- .../v2/text/TextDataSourceV2.scala | 6 +- .../sql/execution/streaming/console.scala | 14 +++ .../sql/execution/streaming/memory.scala | 14 +++ .../sources/RateStreamProvider.scala | 14 +++ .../sources/TextSocketSourceProvider.scala | 14 +++ .../sql/streaming/DataStreamReader.scala | 8 +- .../connector/JavaAdvancedDataSourceV2.java | 4 +- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../JavaPartitionAwareDataSource.java | 4 +- .../JavaReportStatisticsDataSource.java | 4 +- .../JavaSchemaRequiredDataSource.java | 18 ++-- .../sql/connector/JavaSimpleBatchTable.java | 3 +- .../sql/connector/JavaSimpleDataSourceV2.java | 4 +- .../sql/connector/JavaSimpleScanBuilder.java | 3 +- ...SourceV2DataFrameSessionCatalogSuite.scala | 10 +- .../sql/connector/DataSourceV2SQLSuite.scala | 16 ++- .../connector/DataSourceV2SQLUsingSuite.scala | 58 +++++++++-- .../sql/connector/DataSourceV2Suite.scala | 52 +++++++--- .../FileDataSourceV2FallBackSuite.scala | 11 +-- .../connector/SimpleWritableDataSource.scala | 2 +- .../connector/TableCapabilityCheckSuite.scala | 8 +- 
.../sql/connector/V1WriteFallbackSuite.scala | 12 +++ .../command/PlanResolutionSuite.scala | 4 +- .../sources/TextSocketStreamSuite.scala | 12 +++ .../sources/StreamingDataSourceV2Suite.scala | 28 ++++-- 42 files changed, 427 insertions(+), 311 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index e45f750be329b..faf6245cafd62 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -20,7 +20,6 @@ import java.util import org.apache.spark.sql.avro.AvroFileFormat import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.types.StructType @@ -38,10 +37,7 @@ class AvroDataSourceV2 extends FileDataSourceV2 { AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = getTableName(paths) AvroTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index c15f08d78741d..01134efb427b1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.kafka010.KafkaConfigUpdater import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder} @@ -108,6 +109,19 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParameters)) } + override def getTable(schema: StructType, properties: ju.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Kafka source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: ju.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Kafka source does not support user-specified schema/partitioning.") + } + override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { val 
includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) new KafkaTable(includeHeaders) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java deleted file mode 100644 index 308ef8a52c5e4..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsSpecifiedSchemaPartitioning.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog; - -import java.util.Map; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to - * return a table instance with specified schema and partitioning, so that they may avoid expensive - * schema/partitioning inference. - */ -@Evolving -public interface SupportsSpecifiedSchemaPartitioning extends TableProvider { - - /** - * Return a {@link Table} instance with the given table schema, partitioning and properties to do - * read/write . The returned table must report the same schema and partitioning with the given - * ones. - * - * @param schema The schema of the table to load. - * @param partitions The data partitioning of the table to load. - * @param properties The properties of the table to load. It should be sufficient to define and - * access a table. The properties map may be {@link CaseInsensitiveStringMap}. - */ - Table getTable( - StructType schema, - Transform[] partitions, - Map properties); -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java index 2cf76761805d9..c003fcd924f74 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -17,7 +17,11 @@ package org.apache.spark.sql.connector.catalog; +import java.util.Map; + import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** @@ -35,12 +39,45 @@ public interface TableProvider { /** - * Return a {@link Table} instance with the given table options to do read/write. + * Return a {@link Table} instance with the user-specified table properties to do read/write. 
* Implementations should infer the table schema and partitioning. * - * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param options The user-specified table properties that can identify a table, e.g. file path, + * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map. */ // TODO: this should take a Map as table properties. Table getTable(CaseInsensitiveStringMap options); + + /** + * Return a {@link Table} instance with the user-specified table schema and properties to do + * read/write. Implementations should infer the table partitioning. The returned table must report + * the same schema as the user-specified one, or Spark will fail the operation. + * + * @param schema The user-specified table schema. + * @param properties The user-specified table properties that can identify a table, e.g. file + * path, Kafka topic name, etc. The properties map may be + * {@link CaseInsensitiveStringMap}. + * + * @throws IllegalArgumentException if the user-specified schema does not match the actual table + * schema. + */ + Table getTable(StructType schema, Map properties); + + /** + * Return a {@link Table} instance with the user-specified table schema, partitioning and + * properties to do read/write. The returned table must report the same schema and partitioning + * as the user-specified ones, or Spark will fail the operation. + * + * @param schema The user-specified table schema. + * @param partitioning The user-specified table partitioning. + * @param properties The properties of the table to load. It should be sufficient to define and + * access a table. The properties map may be {@link CaseInsensitiveStringMap}. + * + * @throws IllegalArgumentException if the user-specified schema/partitioning does not match the + * actual table schema/partitioning. + */ + Table getTable( + StructType schema, + Transform[] partitioning, + Map properties); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala index a7f9728cf2b80..be14b17701276 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala @@ -77,15 +77,16 @@ class CatalogManager( // If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2 // session catalog. Otherwise, return the default session catalog. 
def v2SessionCatalog: CatalogPlugin = { - try { - catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog()) - } catch { - case NonFatal(_) => - logError( - "Fail to instantiate the custom v2 session catalog: " + - conf.getConfString(CatalogManager.SESSION_CATALOG_NAME)) - defaultSessionCatalog - } + conf.getConf(SQLConf.V2_SESSION_CATALOG).map { customV2SessionCatalog => + try { + catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog()) + } catch { + case NonFatal(_) => + logError( + "Fail to instantiate the custom v2 session catalog: " + customV2SessionCatalog) + defaultSessionCatalog + } + }.getOrElse(defaultSessionCatalog) } private def getDefaultNamespace(c: CatalogPlugin) = c match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index cb23433b989a3..8cf0558b2c25a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -30,8 +30,8 @@ import org.apache.spark.sql.types.StructType /** * An implementation of catalog v2 `Table` to expose v1 table metadata. */ -private[sql] case class V1Table(v1Table: CatalogTable) extends Table { - assert(v1Table.provider.isDefined) +private[sql] case class V1Table(catalogTable: CatalogTable) extends Table { + assert(catalogTable.provider.isDefined) implicit class IdentifierHelper(identifier: TableIdentifier) { def quoted: String = { @@ -52,34 +52,34 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { } } - override lazy val properties: util.Map[String, String] = { - val pathOption = v1Table.storage.locationUri match { + lazy val options: Map[String, String] = { + catalogTable.storage.locationUri match { case Some(uri) => - Some("path" -> uri.toString) + catalogTable.storage.properties + ("path" -> uri.toString) case _ => - None + catalogTable.storage.properties } - val providerOption = "provider" -> v1Table.provider.get - (v1Table.storage.properties ++ v1Table.properties ++ pathOption + providerOption).asJava } - override def schema: StructType = v1Table.schema + override lazy val properties: util.Map[String, String] = catalogTable.properties.asJava + + override lazy val schema: StructType = catalogTable.schema override lazy val partitioning: Array[Transform] = { val partitions = new mutable.ArrayBuffer[Transform]() - v1Table.partitionColumnNames.foreach { col => + catalogTable.partitionColumnNames.foreach { col => partitions += LogicalExpressions.identity(col) } - v1Table.bucketSpec.foreach { spec => + catalogTable.bucketSpec.foreach { spec => partitions += LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*) } partitions.toArray } - override def name: String = v1Table.identifier.quoted + override def name: String = catalogTable.identifier.quoted override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bfc7871d6bcdf..3d28b5e93a17e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1982,8 +1982,7 @@ object SQLConf { "passed the Spark built-in session catalog, so that it may delegate calls to the " + "built-in session catalog.") 
.stringConf - .createWithDefault( - "org.apache.spark.sql.execution.datasources.v2.CatalogExtensionForTableProvider") + .createOptional val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast") .doc("When true, the upcast will be loose and allows string to atomic types.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 10638df9ca52b..94a1fb9cb78e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -215,12 +215,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = userSpecifiedSchema match { - case Some(schema) => - DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, schema, dsOptions) - case _ => - provider.getTable(dsOptions) - } + val table = DataSourceV2Utils.loadTableFromTableProvider( + provider, source, userSpecifiedSchema, dsOptions) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supports(BATCH_READ) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 3f4f29c3e135a..0a13ddaebabbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.sources.DataSourceRegister @@ -36,6 +37,19 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read noop source with user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read noop source with user-specified schema/partitioning.") + } } private[noop] object NoopTable extends Table with SupportsWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala deleted file mode 100644 index 7f38d3ff34d7a..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CatalogExtensionForTableProvider.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* 
- * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import java.util - -import scala.util.control.NonFatal - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, SupportsSpecifiedSchemaPartitioning, Table} -import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -class CatalogExtensionForTableProvider extends DelegatingCatalogExtension { - - private val conf = SQLConf.get - - override def loadTable(ident: Identifier): Table = { - val table = super.loadTable(ident) - tryResolveTableProvider(table) - } - - override def createTable( - ident: Identifier, - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { - val provider = properties.getOrDefault("provider", conf.defaultDataSourceName) - val maybeProvider = DataSource.lookupDataSourceV2(provider, conf) - val (actualSchema, actualPartitioning) = if (maybeProvider.isDefined && schema.isEmpty) { - // A sanity check. The parser should guarantee it. - assert(partitions.isEmpty) - // If `CREATE TABLE ... USING` does not specify table metadata, get the table metadata from - // data source first. - val table = maybeProvider.get.getTable(new CaseInsensitiveStringMap(properties)) - table.schema() -> table.partitioning() - } else { - schema -> partitions - } - super.createTable(ident, actualSchema, actualPartitioning, properties) - // call `loadTable` to make sure the schema/partitioning specified in `CREATE TABLE ... USING` - // matches the actual data schema/partitioning. If error happens during table loading, drop - // the table. - try { - loadTable(ident) - } catch { - case NonFatal(e) => - dropTable(ident) - throw e - } - } - - private def tryResolveTableProvider(table: Table): Table = { - val providerName = table.properties().get("provider") - assert(providerName != null) - DataSource.lookupDataSourceV2(providerName, conf).map { - // TODO: support file source v2 in CREATE TABLE USING. 
-      case _: FileDataSourceV2 => table
-
-      case s: SupportsSpecifiedSchemaPartitioning =>
-        s.getTable(table.schema, table.partitioning, table.properties)
-
-      case provider =>
-        val actualTable = provider.getTable(new CaseInsensitiveStringMap(table.properties))
-        if (actualTable.schema() != table.schema) {
-          throw new AnalysisException(s"Table provider '$providerName' returns a table " +
-            "which has inappropriate schema:\n" +
-            s"schema in Spark meta-store:\t${table.schema}\n" +
-            s"schema from table provider:\t${actualTable.schema}")
-        }
-        if (!actualTable.partitioning.sameElements(table.partitioning)) {
-          throw new AnalysisException(s"Table provider '$providerName' returns a table " +
-            "which has inappropriate partitioning:\n" +
-            s"partitioning in Spark meta-store: ${table.partitioning.mkString(", ")}\n" +
-            s"partitioning from table provider: ${actualTable.partitioning.mkString(", ")}")
-        }
-        actualTable
-    }.getOrElse(table)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index 4b278ce2ea719..2e458103fada3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -21,7 +21,8 @@ import java.util.regex.Pattern
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsSpecifiedSchemaPartitioning, Table, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -61,20 +62,41 @@ private[sql] object DataSourceV2Utils extends Logging {
     }
   }
 
-  def loadTableWithUserSpecifiedSchema(
+  def loadTableFromTableProvider(
       provider: TableProvider,
-      schema: StructType,
+      providerName: String,
+      userSpecifiedSchema: Option[StructType],
       options: CaseInsensitiveStringMap): Table = {
-    provider match {
-      case s: SupportsSpecifiedSchemaPartitioning =>
-        // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified
-        // partitioning.
-        s.getTable(schema, Array.empty, options)
+    userSpecifiedSchema match {
+      case Some(schema) =>
+        val table = provider.getTable(schema, options)
+        validateTableSchemaAndPartitioning(providerName, table, schema, table.partitioning())
+        table
       case _ =>
-        throw new UnsupportedOperationException(
-          provider.getClass.getSimpleName + " source does not support user-specified schema");
+        provider.getTable(options)
+        // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified
+        // partitioning.
+    }
+  }
+
+  def validateTableSchemaAndPartitioning(
+      providerName: String,
+      table: Table,
+      expectedSchema: StructType,
+      expectedPartitioning: Array[Transform]): Unit = {
+    if (table.schema() != expectedSchema) {
+      throw new AnalysisException(s"Table provider '$providerName' returns a table " +
+        "which has inappropriate schema:\n" +
+        s"Expected Schema: $expectedSchema\n" +
+        s"Actual Schema: ${table.schema}")
+    }
+    if (!table.partitioning().sameElements(expectedPartitioning)) {
+      throw new AnalysisException(s"Table provider '$providerName' returns a table " +
+        "which has inappropriate partitioning:\n" +
+        s"Expected Partitioning: ${expectedPartitioning.mkString(", ")}\n" +
+        s"Actual Partitioning: ${table.partitioning().mkString(", ")}")
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
index 659f1af8452a5..71d1d5e527f49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
@@ -16,20 +16,23 @@
  */
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util
+
 import com.fasterxml.jackson.databind.ObjectMapper
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.catalog.{SupportsSpecifiedSchemaPartitioning, TableProvider}
+import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
+import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
 /**
  * A base interface for data source v2 implementations of the built-in file-based data sources.
  */
-trait FileDataSourceV2 extends TableProvider
-  with SupportsSpecifiedSchemaPartitioning with DataSourceRegister {
+trait FileDataSourceV2 extends TableProvider with DataSourceRegister {
 
   /**
    * Returns a V1 [[FileFormat]] class of the same file data source.
@@ -60,4 +63,12 @@ trait FileDataSourceV2 extends TableProvider val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "file source v2 does not support user-specified partitioning yet.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index e71aea93f263f..e6202b267b56c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -42,7 +42,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import V2SessionCatalog._ - override def defaultNamespace: Array[String] = Array("default") + override val defaultNamespace: Array[String] = Array("default") override def name: String = CatalogManager.SESSION_CATALOG_NAME @@ -69,6 +69,20 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) } } + private def tryResolveTableProvider(v1Table: V1Table): Table = { + val providerName = v1Table.catalogTable.provider.get + DataSource.lookupDataSourceV2(providerName, conf).map { + // TODO: support file source v2 in CREATE TABLE USING. + case _: FileDataSourceV2 => v1Table + + case provider => + val table = provider.getTable(v1Table.schema, v1Table.partitioning, v1Table.properties) + DataSourceV2Utils.validateTableSchemaAndPartitioning( + providerName, table, v1Table.schema, v1Table.partitioning) + table + }.getOrElse(v1Table) + } + override def loadTable(ident: Identifier): Table = { val catalogTable = try { catalog.getTableMetadata(ident.asTableIdentifier) @@ -81,7 +95,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) throw new NoSuchTableException(ident) } - V1Table(catalogTable) + tryResolveTableProvider(V1Table(catalogTable)) } override def invalidateTable(ident: Identifier): Unit = { @@ -94,8 +108,25 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) partitions: Array[Transform], properties: util.Map[String, String]): Table = { - val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(partitions) - val provider = properties.getOrDefault("provider", conf.defaultDataSourceName) + val providerName = properties.getOrDefault("provider", conf.defaultDataSourceName) + // It's guaranteed that we only call `V2SessionCatalog.createTable` if the table provider is v2. + val provider = DataSource.lookupDataSourceV2(providerName, conf).get + val (actualSchema, actualPartitioning) = if (schema.isEmpty) { + // A sanity check. The parser should guarantee it. + assert(partitions.isEmpty) + // If `CREATE TABLE ... USING` does not specify table metadata, get the table metadata from + // data source first. + val table = provider.getTable(new CaseInsensitiveStringMap(properties)) + table.schema() -> table.partitioning() + } else { + // The schema/partitioning is specified in `CREATE TABLE ... USING`, validate it. 
+ val table = provider.getTable(schema, partitions, properties) + DataSourceV2Utils.validateTableSchemaAndPartitioning( + providerName, table, schema, partitions) + schema -> partitions + } + + val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(actualPartitioning) val tableProperties = properties.asScala val location = Option(properties.get(LOCATION_TABLE_PROP)) val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap) @@ -106,8 +137,8 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) identifier = ident.asTableIdentifier, tableType = tableType, storage = storage, - schema = schema, - provider = Some(provider), + schema = actualSchema, + provider = Some(providerName), partitionColumnNames = partitionColumns, bucketSpec = maybeBucketSpec, properties = tableProperties.toMap, @@ -121,7 +152,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) throw new TableAlreadyExistsException(ident) } - V1Table(tableDesc) + loadTable(ident) } override def alterTable( @@ -136,16 +167,15 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) val properties = CatalogV2Util.applyPropertiesChanges(catalogTable.properties, changes) val schema = CatalogV2Util.applySchemaChanges(catalogTable.schema, changes) - val updatedTable = catalogTable.copy(properties = properties, schema = schema) try { - catalog.alterTable(updatedTable) + catalog.alterTable(catalogTable.copy(properties = properties, schema = schema)) } catch { case _: NoSuchTableException => throw new NoSuchTableException(ident) } - V1Table(updatedTable) + loadTable(ident) } override def dropTable(ident: Identifier): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index f826869a2d982..cda654f22e329 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.csv import java.util import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 @@ -38,10 +37,7 @@ class CSVDataSourceV2 extends FileDataSourceV2 { CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = getTableName(paths) CSVTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala index d5d35f77a6ebb..fd019c2d144a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala @@ -19,7 +19,6 @@ package 
org.apache.spark.sql.execution.datasources.v2.json import java.util import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.execution.datasources.v2._ @@ -38,10 +37,7 @@ class JsonDataSourceV2 extends FileDataSourceV2 { JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = getTableName(paths) JsonTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 5e2ffdc43ca8d..ae45030792a96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.orc import java.util import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ @@ -38,10 +37,7 @@ class OrcDataSourceV2 extends FileDataSourceV2 { OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = getTableName(paths) OrcTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala index b3bc5eb575997..85d7623fd42c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import java.util import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2._ @@ -38,10 +37,7 @@ class ParquetDataSourceV2 extends FileDataSourceV2 { ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = 
getTableName(paths) ParquetTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala index fe15e55351686..c683097b47bfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.text import java.util import org.apache.spark.sql.connector.catalog.Table -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 @@ -38,10 +37,7 @@ class TextDataSourceV2 extends FileDataSourceV2 { TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { val paths = getPaths(properties) val tableName = getTableName(paths) TextTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 20eb7ae5a6d96..11b6aa6955243 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.{SupportsTruncate, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite @@ -43,6 +44,19 @@ class ConsoleSinkProvider extends TableProvider ConsoleTable } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read console sink with user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read console sink with user-specified schema/partitioning.") + } + def createRelation( sqlContext: SQLContext, mode: SaveMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 911a526428cf4..4961b9460993d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import 
org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.internal.SQLConf @@ -98,6 +99,19 @@ object MemoryStreamTableProvider extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "memory stream does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "memory stream does not support user-specified schema/partitioning.") + } } class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3f7b0377f1eab..9aefd7763c779 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream @@ -72,6 +73,19 @@ class RateStreamProvider extends TableProvider with DataSourceRegister { new RateStreamTable(rowsPerSecond, rampUpTimeSeconds, numPartitions) } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Rate stream source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Rate stream source does not support user-specified schema/partitioning.") + } + override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index fae3cb765c0c9..e91b145e102f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import 
org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream @@ -63,6 +64,19 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit options.getBoolean("includeTimestamp", false)) } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "socket source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "socket source does not support user-specified schema/partitioning.") + } + /** String that represents the format that this data source provider uses. */ override def shortName(): String = "socket" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 65e98b2e833fd..6eb606d386a27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -187,12 +187,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) - val table = userSpecifiedSchema match { - case Some(schema) => - DataSourceV2Utils.loadTableWithUserSpecifiedSchema(provider, schema, dsOptions) - case _ => - provider.getTable(dsOptions) - } + val table = DataSourceV2Utils.loadTableFromTableProvider( + provider, source, userSpecifiedSchema, dsOptions) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 9386ab51d64f0..fcfde41663926 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -22,15 +22,15 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaAdvancedDataSourceV2 implements TableProvider { +public class JavaAdvancedDataSourceV2 extends TestingV2Source { @Override public Table getTable(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 76da45e182b3c..416ffd125abf1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -20,8 +20,8 @@ import java.io.IOException; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -33,7 +33,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; -public class JavaColumnarDataSourceV2 implements TableProvider { +public class JavaColumnarDataSourceV2 extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index fbbc457b2945d..3b9e251ee9743 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -22,17 +22,17 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution; import org.apache.spark.sql.connector.read.partitioning.Distribution; import org.apache.spark.sql.connector.read.partitioning.Partitioning; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaPartitionAwareDataSource implements TableProvider { +public class JavaPartitionAwareDataSource extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportPartitioning { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index 49438fe668d56..ba9b6201eca92 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -19,15 +19,15 @@ import java.util.OptionalLong; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.Statistics; import org.apache.spark.sql.connector.read.SupportsReportStatistics; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaReportStatisticsDataSource implements TableProvider { +public class JavaReportStatisticsDataSource extends 
TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { @Override public Statistics estimateStatistics() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index a9ac9e77edf86..c1214bf76fa63 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -19,7 +19,6 @@ import java.util.Map; -import org.apache.spark.sql.connector.catalog.SupportsSpecifiedSchemaPartitioning; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.expressions.Transform; @@ -28,8 +27,7 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaSchemaRequiredDataSource - implements TableProvider, SupportsSpecifiedSchemaPartitioning { +public class JavaSchemaRequiredDataSource implements TableProvider { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -51,10 +49,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable( - StructType schema, - Transform[] partitions, - Map properties) { + public Table getTable(StructType schema, Map properties) { return new JavaSimpleBatchTable() { @Override @@ -69,6 +64,15 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { }; } + @Override + public Table getTable( + StructType schema, + Transform[] partitioning, + Map properties) { + assert partitioning.length == 0; + return getTable(schema, properties); + } + @Override public Table getTable(CaseInsensitiveStringMap options) { throw new IllegalArgumentException("requires a user-supplied schema"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java index 97b00477e1764..71cf97b56fe54 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Set; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; @@ -34,7 +35,7 @@ abstract class JavaSimpleBatchTable implements Table, SupportsRead { @Override public StructType schema() { - return new StructType().add("i", "int").add("j", "int"); + return TestingV2Source.schema(); } @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 8b6d71b986ff7..7d48e5aa4ee0b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -17,13 +17,13 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import 
org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaSimpleDataSourceV2 implements TableProvider { +public class JavaSimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java index 7cbba00420928..bdd9dd3ea0ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java @@ -17,6 +17,7 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.connector.read.Scan; @@ -37,7 +38,7 @@ public Batch toBatch() { @Override public StructType readSchema() { - return new StructType().add("i", "int").add("j", "int"); + return TestingV2Source.schema(); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 207ece83759ed..395c37cb1f0c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -91,12 +91,6 @@ class DataSourceV2DataFrameSessionCatalogSuite } } -class InMemoryTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("D'oh!") - } -} - class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { override def newTable( name: String, @@ -130,7 +124,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable } } -private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] +private[connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -139,7 +133,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio spark.sessionState.catalogManager.catalog(name) } - protected val v2Format: String = classOf[InMemoryTableProvider].getName + protected val v2Format: String = classOf[FakeV2Provider].getName protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index d353e6b3f56d8..213d496324460 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.connector +import java.util + import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} import 
org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG import org.apache.spark.sql.sources.SimpleScanSource @@ -1163,6 +1166,17 @@ class DataSourceV2SQLSuite /** Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("Unnecessary for DDL tests") + throw new UnsupportedOperationException() + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException() + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala index c40de492e137e..46c0c7c1a9d95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SupportsSpecifiedSchemaPartitioning, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -173,8 +173,7 @@ class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { } } - // TODO: enable it when DELETE FROM supports v2 session catalog. 
- ignore("DELETE FROM") { + test("DELETE FROM") { withTable("t") { sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") sql("INSERT INTO t SELECT 1, 1") @@ -193,8 +192,7 @@ class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { sql("INSERT INTO t SELECT 1, -1") checkAnswer(spark.table("t"), Row(1, -1)) - sql("ALTER TABLE t DROP COLUMN i") - val e = intercept[AnalysisException](sql("SELECT * FROM t")) + val e = intercept[AnalysisException](sql("ALTER TABLE t DROP COLUMN i")) assert(e.message.contains("returns a table which has inappropriate schema")) } } @@ -225,6 +223,17 @@ class ReadWriteV2Source extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { ReadWriteV2Source.table } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } } object PartitionedV2Source { @@ -242,6 +251,17 @@ class PartitionedV2Source extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { PartitionedV2Source.table } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } } class ReadOnlyV2Source extends TableProvider { @@ -258,6 +278,17 @@ class ReadOnlyV2Source extends TableProvider { rows.withRow(InternalRow(1)).withRow(InternalRow(-1)) table.withData(Array(rows)) } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } } class WriteOnlyV2Source extends TableProvider { @@ -271,13 +302,28 @@ class WriteOnlyV2Source extends TableProvider { } } } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(CaseInsensitiveStringMap.empty()) + } } -class DynamicSchemaV2Source extends TableProvider with SupportsSpecifiedSchemaPartitioning { +class DynamicSchemaV2Source extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { new InMemoryTable("dynamic-schema", new StructType(), Array.empty, options) {} } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", schema, Array.empty, properties) {} + } + override def getTable( schema: StructType, partitions: Array[Transform], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 17befa47f0bd9..a0b56985ac0fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -28,7 +28,7 @@ import 
test.org.apache.spark.sql.connector._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsSpecifiedSchemaPartitioning, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read._ @@ -416,9 +416,29 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } +abstract class TestingV2Source extends TableProvider { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + assert(schema == TestingV2Source.schema) + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + assert(schema == TestingV2Source.schema) + assert(partitioning.isEmpty) + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +object TestingV2Source { + val schema = new StructType().add("i", "int").add("j", "int") +} + abstract class SimpleBatchTable extends Table with SupportsRead { - override def schema(): StructType = new StructType().add("i", "int").add("j", "int") + override def schema(): StructType = TestingV2Source.schema override def name(): String = this.getClass.toString @@ -432,12 +452,12 @@ abstract class SimpleScanBuilder extends ScanBuilder override def toBatch: Batch = this - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def readSchema(): StructType = TestingV2Source.schema override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory } -class SimpleSinglePartitionSource extends TableProvider { +class SimpleSinglePartitionSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -454,7 +474,7 @@ class SimpleSinglePartitionSource extends TableProvider { // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. 
-class SimpleDataSourceV2 extends TableProvider { +class SimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -469,7 +489,7 @@ class SimpleDataSourceV2 extends TableProvider { } } -class AdvancedDataSourceV2 extends TableProvider { +class AdvancedDataSourceV2 extends TestingV2Source { override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { @@ -559,7 +579,7 @@ class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderF } -class SchemaRequiredDataSource extends TableProvider with SupportsSpecifiedSchemaPartitioning { +class SchemaRequiredDataSource extends TableProvider { class MyScanBuilder(schema: StructType) extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = Array.empty @@ -572,9 +592,8 @@ class SchemaRequiredDataSource extends TableProvider with SupportsSpecifiedSchem } override def getTable( - schema: StructType, - partitions: Array[Transform], - options: util.Map[String, String]): Table = { + schema: StructType, + properties: util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema @@ -584,9 +603,16 @@ class SchemaRequiredDataSource extends TableProvider with SupportsSpecifiedSchem } } } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(schema, properties) + } } -class ColumnarDataSourceV2 extends TableProvider { +class ColumnarDataSourceV2 extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { @@ -651,7 +677,7 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } } -class PartitionAwareDataSource extends TableProvider { +class PartitionAwareDataSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder with SupportsReportPartitioning{ @@ -719,7 +745,7 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { } } -class ReportStatisticsDataSource extends TableProvider { +class ReportStatisticsDataSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder with SupportsReportStatistics { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 079853a653193..0f45d6f9a4820 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} -import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.connector.write.WriteBuilder import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution} @@ -47,10 +46,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { new DummyReadOnlyFileTable } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, 
String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException } } @@ -78,10 +74,7 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { new DummyWriteOnlyFileTable } - override def getTable( - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 22d3750022c57..baba8fe95e9f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. */ -class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TestingV2Source with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index ce6d56cf84df1..99bff90700749 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -39,7 +39,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession { private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = { StreamingRelationV2( - TestTableProvider, + new FakeV2Provider, "fake", table, CaseInsensitiveStringMap.empty(), @@ -203,12 +203,6 @@ private case object TestRelation extends LeafNode with NamedRelation { override def output: Seq[AttributeReference] = TableCapabilityCheckSuite.schema.toAttributes } -private object TestTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException - } -} - private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index de843ba4375d0..1cea1448f5eda 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -176,6 +176,7 @@ class InMemoryV1Provider extends TableProvider with DataSourceRegister with CreatableRelationProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { InMemoryV1Provider.tables.getOrElse(options.get("name"), { @@ -188,6 +189,17 @@ class InMemoryV1Provider }) } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new 
CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + override def shortName(): String = "in-memory" override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 0f4fe656dd20a..03f2f888f04c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.expressions.{EqualTo, IntegerLiteral, StringLiteral} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelect, CreateV2Table, DescribeTable, DropTable, LogicalPlan, SubqueryAlias, UpdateTable} -import org.apache.spark.sql.connector.InMemoryTableProvider +import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ - private val v2Format = classOf[InMemoryTableProvider].getName + private val v2Format = classOf[FakeV2Provider].getName private val table: Table = { val t = mock(classOf[Table]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index f0bf6d2b06120..e82c4518cce93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -193,6 +193,18 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { } } + test("user-specified schema given") { + val provider = new TextSocketSourceProvider + val userSpecifiedSchema = StructType( + StructField("name", StringType) :: + StructField("area", StringType) :: Nil) + val params = Map("host" -> "localhost", "port" -> "1234") + val exception = intercept[UnsupportedOperationException] { + provider.getTable(userSpecifiedSchema, params.asJava) + } + assert(exception.getMessage.contains("socket source does not support user-specified schema")) + } + test("input row metrics") { serverThread = new ServerThread() serverThread.start() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index e9d148c38e6cb..9e5bb586d5461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,6 +25,7 @@ import 
scala.collection.JavaConverters._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReaderFactory, ContinuousStream, MicroBatchStream, Offset, PartitionOffset} import org.apache.spark.sql.connector.write.{WriteBuilder, WriterCommitMessage} @@ -90,9 +91,22 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { } } +trait FakeTableProvider extends TableProvider { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } +} + class FakeReadMicroBatchOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" @@ -115,7 +129,7 @@ class FakeReadMicroBatchOnly class FakeReadContinuousOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" @@ -136,7 +150,7 @@ class FakeReadContinuousOnly } } -class FakeReadBothModes extends DataSourceRegister with TableProvider { +class FakeReadBothModes extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-read-microbatch-continuous" override def getTable(options: CaseInsensitiveStringMap): Table = { @@ -153,7 +167,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { } } -class FakeReadNeitherMode extends DataSourceRegister with TableProvider { +class FakeReadNeitherMode extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-read-neither-mode" override def getTable(options: CaseInsensitiveStringMap): Table = { @@ -167,7 +181,7 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { class FakeWriteOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" @@ -182,7 +196,7 @@ class FakeWriteOnly } } -class FakeNoWrite extends DataSourceRegister with TableProvider { +class FakeNoWrite extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-write-neither-mode" override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { @@ -200,7 +214,7 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with TableProvider with StreamSinkProvider { + with FakeTableProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, From 1235d78badc358db0ed586f9419a2bf766a4e22b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Oct 2019 21:26:07 +0800 Subject: [PATCH 4/5] update --- .../spark/sql/v2/avro/AvroDataSourceV2.scala | 6 +-- .../sql/kafka010/KafkaSourceProvider.scala | 4 +- .../sql/connector/catalog/TableProvider.java | 11 +++-- .../apache/spark/sql/DataFrameReader.scala | 
6 ++- .../datasources/noop/NoopDataSource.scala | 2 +- .../datasources/v2/DataSourceV2Utils.scala | 44 +------------------ .../datasources/v2/V2SessionCatalog.scala | 24 +++++++++- .../datasources/v2/csv/CSVDataSourceV2.scala | 6 +-- .../v2/json/JsonDataSourceV2.scala | 6 +-- .../datasources/v2/orc/OrcDataSourceV2.scala | 6 +-- .../v2/parquet/ParquetDataSourceV2.scala | 6 +-- .../v2/text/TextDataSourceV2.scala | 6 +-- .../sql/execution/streaming/console.scala | 2 +- .../sql/execution/streaming/memory.scala | 2 +- .../sources/RateStreamProvider.scala | 11 ++--- .../sources/TextSocketSourceProvider.scala | 18 ++++---- .../sql/streaming/DataStreamReader.scala | 6 ++- .../connector/JavaAdvancedDataSourceV2.java | 2 +- .../connector/JavaColumnarDataSourceV2.java | 3 +- .../JavaPartitionAwareDataSource.java | 3 +- .../JavaReportStatisticsDataSource.java | 3 +- .../JavaSchemaRequiredDataSource.java | 2 +- .../sql/connector/JavaSimpleDataSourceV2.java | 4 +- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- .../connector/DataSourceV2SQLUsingSuite.scala | 18 ++++---- .../sql/connector/DataSourceV2Suite.scala | 18 ++++---- .../FileDataSourceV2FallBackSuite.scala | 4 +- .../connector/SimpleWritableDataSource.scala | 8 ++-- .../sql/connector/V1WriteFallbackSuite.scala | 7 ++- .../sources/StreamingDataSourceV2Suite.scala | 32 +++++++------- 30 files changed, 130 insertions(+), 142 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index faf6245cafd62..5fc371b3b4437 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -31,10 +31,10 @@ class AvroDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "avro" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + AvroTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 01134efb427b1..6b61dfe3311ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -122,8 +122,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister "Kafka source does not support user-specified schema/partitioning.") } - override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { - val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + override def getTable(properties: ju.Map[String, String]): KafkaTable = { + val includeHeaders = Option(properties.get(INCLUDE_HEADERS)).map(_.toBoolean).getOrElse(false) new KafkaTable(includeHeaders) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java 
index c003fcd924f74..a57e35507aca2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -42,16 +42,15 @@ public interface TableProvider { * Return a {@link Table} instance with the user-specified table properties to do read/write. * Implementations should infer the table schema and partitioning. * - * @param options The user-specified table properties that can identify a table, e.g. file path, - * Kafka topic name, etc. It's an immutable case-insensitive string-to-string map. + * @param properties The user-specified table properties that can identify a table, e.g. file + * path, Kafka topic name, etc. The properties map may be + * {@link CaseInsensitiveStringMap}. */ - // TODO: this should take a Map as table properties. - Table getTable(CaseInsensitiveStringMap options); + Table getTable(Map<String, String> properties); /** * Return a {@link Table} instance with the user-specified table schema and properties to do - * read/write. Implementations should infer the table partitioning. The returned table must report - * the same schema with the user-specified one, or Spark will fail the operation. + * read/write. Implementations should infer the table partitioning. * * @param schema The user-specified table schema. * @param properties The user-specified table properties that can identify a table, e.g. file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 94a1fb9cb78e8..1a9c08b30b769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -215,8 +215,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = DataSourceV2Utils.loadTableFromTableProvider( - provider, source, userSpecifiedSchema, dsOptions) + val table = userSpecifiedSchema match { + case Some(schema) => provider.getTable(schema, dsOptions) + case _ => provider.getTable(dsOptions) + } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supports(BATCH_READ) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 0a13ddaebabbc..7f445f2d280ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable + override def getTable(properties: util.Map[String, String]): Table = NoopTable override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index 2e458103fada3..52294ae2cb851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -20,12 +20,8 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.regex.Pattern import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, Table, TableProvider} -import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, TableProvider} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] object DataSourceV2Utils extends Logging { @@ -61,42 +57,4 @@ private[sql] object DataSourceV2Utils extends Logging { case _ => Map.empty } } - - def loadTableFromTableProvider( - provider: TableProvider, - providerName: String, - userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveStringMap): Table = { - userSpecifiedSchema match { - case Some(schema) => - val table = provider.getTable(schema, options) - validateTableSchemaAndPartitioning(providerName, table, schema, table.partitioning()) - table - - case _ => - provider.getTable(options) - - // TODO: `DataFrameReader`/`DataStreamReader` should have an API to set user-specified - // partitioning. - } - } - - def validateTableSchemaAndPartitioning( - providerName: String, - table: Table, - expectedSchema: StructType, - expectedPartitioning: Array[Transform]): Unit = { - if (table.schema() != expectedSchema) { - throw new AnalysisException(s"Table provider '$providerName' returns a table " + - "which has inappropriate schema:\n" + - s"Expected Schema: $expectedSchema\n" + - s"Actual Schema: ${table.schema}") - } - if (!table.partitioning().sameElements(expectedPartitioning)) { - throw new AnalysisException(s"Table provider '$providerName' returns a table " + - "which has inappropriate partitioning:\n" + - s"Expected Partitioning: ${expectedPartitioning.mkString(", ")}\n" + - s"Actual Partitioning: ${table.partitioning().mkString(", ")}") - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index e6202b267b56c..e8eb868b59e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -23,6 +23,7 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} @@ -77,12 +78,31 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) case provider => val table = provider.getTable(v1Table.schema, v1Table.partitioning, v1Table.properties) - DataSourceV2Utils.validateTableSchemaAndPartitioning( + validateTableSchemaAndPartitioning( providerName, table, v1Table.schema, v1Table.partitioning) 
table }.getOrElse(v1Table) } + private def validateTableSchemaAndPartitioning( + providerName: String, + table: Table, + metaStoreSchema: StructType, + metaStorePartitioning: Array[Transform]): Unit = { + if (table.schema() != metaStoreSchema) { + throw new AnalysisException(s"Table provider '$providerName' reports a different data " + + "schema from the one in Spark meta-store:\n" + + s"Schema in Spark meta-store: $metaStoreSchema\n" + + s"Actual data schema: ${table.schema}") + } + if (!table.partitioning().sameElements(metaStorePartitioning)) { + throw new AnalysisException(s"Table provider '$providerName' reports a different data " + + "partitioning from the one in Spark meta-store:\n" + + s"Partitioning in Spark meta-store: ${metaStorePartitioning.mkString(", ")}\n" + + s"Actual data partitioning: ${table.partitioning().mkString(", ")}") + } + } + override def loadTable(ident: Identifier): Table = { val catalogTable = try { catalog.getTableMetadata(ident.asTableIdentifier) @@ -121,7 +141,7 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) } else { // The schema/partitioning is specified in `CREATE TABLE ... USING`, validate it. val table = provider.getTable(schema, partitions, properties) - DataSourceV2Utils.validateTableSchemaAndPartitioning( + validateTableSchemaAndPartitioning( providerName, table, schema, partitions) schema -> partitions } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index cda654f22e329..ed2a8e50f04fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -31,10 +31,10 @@ class CSVDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "csv" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + CSVTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala index fd019c2d144a1..5ab6024a27aee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala @@ -31,10 +31,10 @@ class JsonDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "json" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + JsonTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index ae45030792a96..c53b99170db78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -31,10 +31,10 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + OrcTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala index 85d7623fd42c0..d17252c9480ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala @@ -31,10 +31,10 @@ class ParquetDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + ParquetTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala index c683097b47bfe..aade932566620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala @@ -31,10 +31,10 @@ class TextDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "text" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + TextTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 11b6aa6955243..61b6fba65026a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -40,7 +40,7 @@ class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { ConsoleTable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 4961b9460993d..f1615b1d518ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -96,7 +96,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. object MemoryStreamTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 9aefd7763c779..df95bcde482e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -49,14 +49,14 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ - override def getTable(options: CaseInsensitiveStringMap): Table = { - val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) + override def getTable(properties: util.Map[String, String]): Table = { + val rowsPerSecond = Option(properties.get(ROWS_PER_SECOND)).map(_.toLong).getOrElse(1L) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") } - val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME)) + val rampUpTimeSeconds = Option(properties.get(RAMP_UP_TIME)) .map(JavaUtils.timeStringAsSec) .getOrElse(0L) if (rampUpTimeSeconds < 0) { @@ -64,8 +64,9 @@ class RateStreamProvider extends TableProvider with DataSourceRegister { s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") } - val numPartitions = options.getInt( - NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism) + val numPartitions = Option(properties.get(NUM_PARTITIONS)).map(_.toInt).getOrElse { + SparkSession.active.sparkContext.defaultParallelism + } if (numPartitions <= 0) { throw new IllegalArgumentException( s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index e91b145e102f1..465e56c806017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: CaseInsensitiveStringMap): Unit = { + private def checkParameters(params: util.Map[String, String]): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") if (!params.containsKey("host")) { @@ -47,7 +47,7 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - params.getBoolean("includeTimestamp", false) + Option(params.get("includeTimestamp")).foreach(_.toBoolean) } match { case Success(_) => case Failure(_) => @@ -55,13 +55,15 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } } - override def getTable(options: CaseInsensitiveStringMap): Table = { - checkParameters(options) + override def getTable(properties: util.Map[String, String]): Table = { + checkParameters(properties) new TextSocketTable( - options.get("host"), - options.getInt("port", -1), - options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), - options.getBoolean("includeTimestamp", false)) + properties.get("host"), + Option(properties.get("port")).map(_.toInt).getOrElse(-1), + Option(properties.get("numPartitions")).map(_.toInt).getOrElse { + SparkSession.active.sparkContext.defaultParallelism + }, + Option(properties.get("includeTimestamp")).map(_.toBoolean).getOrElse(false)) } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 6eb606d386a27..796ab1997acb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -187,8 +187,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) - val table = DataSourceV2Utils.loadTableFromTableProvider( - provider, source, userSpecifiedSchema, dsOptions) + val table = userSpecifiedSchema match { + case Some(schema) => provider.getTable(schema, dsOptions) + case _ => provider.getTable(dsOptions) + } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index fcfde41663926..14429a06ca0c9 
100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -33,7 +33,7 @@ public class JavaAdvancedDataSourceV2 extends TestingV2Source { @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 416ffd125abf1..e5c74b3679450 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql.connector; import java.io.IOException; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.TestingV2Source; @@ -52,7 +53,7 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index 3b9e251ee9743..556507e255c68 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -56,7 +57,7 @@ public Partitioning outputPartitioning() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public Transform[] partitioning() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index ba9b6201eca92..fd8819ae52e96 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -17,6 +17,7 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; import java.util.OptionalLong; import org.apache.spark.sql.connector.TestingV2Source; @@ -54,7 +55,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index c1214bf76fa63..9923b8c34353a 100644 ---
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -74,7 +74,7 @@ public Table getTable( } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 7d48e5aa4ee0b..32c23381100ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -23,6 +23,8 @@ import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import java.util.Map; + public class JavaSimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -37,7 +39,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index a6e98f4bc9c49..0f83fee94e886 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1178,7 +1178,7 @@ class DataSourceV2SQLSuite /** Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { throw new UnsupportedOperationException() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala index 46c0c7c1a9d95..0a8c74aae4e80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala @@ -55,7 +55,7 @@ class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { val e = intercept[AnalysisException]( sql(s"CREATE TABLE t(i INT) USING ${classOf[ReadWriteV2Source].getName}") ) - assert(e.getMessage.contains("returns a table which has inappropriate schema")) + assert(e.getMessage.contains("reports a different data schema")) sql(s"CREATE TABLE t(i INT, j INT) USING ${classOf[ReadWriteV2Source].getName}") sql("INSERT INTO t SELECT 1, -1") @@ -72,7 +72,7 @@ class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { |PARTITIONED BY (i) """.stripMargin) ) - assert(e.getMessage.contains("returns a table which has inappropriate partitioning")) + assert(e.getMessage.contains("reports a different data partitioning")) sql( s""" @@ -193,7 +193,7 @@ class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { checkAnswer(spark.table("t"), Row(1, -1)) val e = intercept[AnalysisException](sql("ALTER TABLE t DROP COLUMN i")) -
assert(e.message.contains("returns a table which has inappropriate schema")) + assert(e.message.contains("reports a different data schema")) } } @@ -220,7 +220,7 @@ object ReadWriteV2Source { class ReadWriteV2Source extends TableProvider { // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data // of in-memory table, we keep the table instance in an object. - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { ReadWriteV2Source.table } @@ -248,7 +248,7 @@ object PartitionedV2Source { class PartitionedV2Source extends TableProvider { // `TableProvider` will be instantiated by reflection every time it's accessed. To keep the data // of in-memory table, we keep the table instance in an object. - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { PartitionedV2Source.table } @@ -265,7 +265,7 @@ class PartitionedV2Source extends TableProvider { } class ReadOnlyV2Source extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { val schema = new StructType().add("i", "int") val partitions = Array.empty[Transform] val properties = util.Collections.emptyMap[String, String] @@ -292,7 +292,7 @@ class ReadOnlyV2Source extends TableProvider { } class WriteOnlyV2Source extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { val schema = new StructType().add("i", "int") val partitions = Array.empty[Transform] val properties = util.Collections.emptyMap[String, String] @@ -316,8 +316,8 @@ class WriteOnlyV2Source extends TableProvider { } class DynamicSchemaV2Source extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - new InMemoryTable("dynamic-schema", new StructType(), Array.empty, options) {} + override def getTable(properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", new StructType(), Array.empty, properties) {} } override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a0b56985ac0fa..df0265a763ad6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -465,7 +465,7 @@ class SimpleSinglePartitionSource extends TestingV2Source { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -482,7 +482,7 @@ class SimpleDataSourceV2 extends TestingV2Source { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -491,7 +491,7 @@ class SimpleDataSourceV2 extends TestingV2Source { class AdvancedDataSourceV2 
extends TestingV2Source { - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } @@ -587,7 +587,7 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } @@ -625,7 +625,7 @@ class ColumnarDataSourceV2 extends TestingV2Source { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -696,7 +696,7 @@ class PartitionAwareDataSource extends TestingV2Source { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -736,8 +736,8 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def getTable(options: CaseInsensitiveStringMap): Table = { - new MyTable(options) { + override def getTable(properties: util.Map[String, String]): Table = { + new MyTable(properties) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") } @@ -762,7 +762,7 @@ class ReportStatisticsDataSource extends TestingV2Source { } } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 0f45d6f9a4820..765ad0858805f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -42,7 +42,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new DummyReadOnlyFileTable } @@ -70,7 +70,7 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new DummyWriteOnlyFileTable } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index baba8fe95e9f2..3ee6c0f780f03 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -131,10 +131,10 @@ class SimpleWritableDataSource extends TestingV2Source with SessionConfigSupport } } - class MyTable(options: CaseInsensitiveStringMap) + class MyTable(properties: util.Map[String, String]) extends SimpleBatchTable with SupportsWrite { - private val path = options.get("path") + private val path = properties.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration override def schema(): StructType = tableSchema @@ -151,8 +151,8 @@ class SimpleWritableDataSource extends TestingV2Source with SessionConfigSupport Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } - override def getTable(options: CaseInsensitiveStringMap): Table = { - new MyTable(options) + override def getTable(properties: util.Map[String, String]): Table = { + new MyTable(properties) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 1cea1448f5eda..de6b985cc7cdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -177,15 +177,14 @@ class InMemoryV1Provider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { - InMemoryV1Provider.tables.getOrElse(options.get("name"), { + InMemoryV1Provider.tables.getOrElse(properties.get("name"), { new InMemoryTableWithV1Fallback( "InMemoryTableWithV1Fallback", new StructType(), Array.empty, - options.asCaseSensitiveMap() - ) + properties) }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 9e5bb586d5461..6eb7933c638e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -93,14 +93,14 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { trait FakeTableProvider extends TableProvider { override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { - getTable(new CaseInsensitiveStringMap(properties)) + getTable(properties) } override def getTable( schema: StructType, partitioning: Array[Transform], properties: util.Map[String, String]): Table = { - getTable(new CaseInsensitiveStringMap(properties)) + getTable(properties) } } @@ -112,8 +112,8 @@ class FakeReadMicroBatchOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def getTable(properties: util.Map[String, String]): Table = { + LastReadOptions.options = properties new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -135,8 +135,8 @@ class FakeReadContinuousOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def getTable(properties: util.Map[String, String]): 
Table = { + LastReadOptions.options = properties new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -153,7 +153,7 @@ class FakeReadContinuousOnly class FakeReadBothModes extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -170,7 +170,7 @@ class FakeReadBothModes extends DataSourceRegister with FakeTableProvider { class FakeReadNeitherMode extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-read-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -187,8 +187,8 @@ class FakeWriteOnly override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastWriteOptions.options = options + override def getTable(properties: util.Map[String, String]): Table = { + LastWriteOptions.options = properties new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -198,7 +198,7 @@ class FakeWriteOnly class FakeNoWrite extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-write-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -226,7 +226,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -235,7 +235,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister } object LastReadOptions { - var options: CaseInsensitiveStringMap = _ + var options: util.Map[String, String] = _ def clear(): Unit = { options = null @@ -243,7 +243,7 @@ object LastReadOptions { } object LastWriteOptions { - var options: CaseInsensitiveStringMap = _ + var options: util.Map[String, String] = _ def clear(): Unit = { options = null @@ -361,7 +361,7 @@ class StreamingDataSourceV2Suite extends StreamTest { eventually(timeout(streamingTimeout)) { // Write options should not be set. assert(!LastWriteOptions.options.containsKey(readOptionName)) - assert(LastReadOptions.options.getBoolean(readOptionName, false)) + assert(LastReadOptions.options.get(readOptionName) == "true") } } } @@ -372,7 +372,7 @@ class StreamingDataSourceV2Suite extends StreamTest { eventually(timeout(streamingTimeout)) { // Read options should not be set. 
assert(!LastReadOptions.options.containsKey(writeOptionName)) - assert(LastWriteOptions.options.getBoolean(writeOptionName, false)) + assert(LastWriteOptions.options.get(writeOptionName) == "true") } } } From cfbe0a75f80e88d4a5831785d05fb9b708c5ada3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Oct 2019 00:55:18 +0800 Subject: [PATCH 5/5] fix test --- .../apache/spark/sql/execution/command/DDLSuite.scala | 2 +- .../spark/sql/hive/execution/HiveDDLSuite.scala | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 70b1db8e5f0d2..0a0b79afbff67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2671,7 +2671,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") } - assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + assert(e.message.contains("'v1' is a view not a table")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 4253fe2e1edcb..406e41db791bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -769,6 +769,11 @@ class HiveDDLSuite } } + private def assertErrorForProcessViewAsTable(sqlText: String): Unit = { + val message = intercept[AnalysisException](sql(sqlText)).getMessage + assert(message.contains("Invalid command")) + } + private def assertErrorForAlterTableOnView(sqlText: String): Unit = { val message = intercept[AnalysisException](sql(sqlText)).getMessage assert(message.contains("Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) @@ -831,13 +836,13 @@ class HiveDDLSuite assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET SERDE 'whatever'")