diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala index c6f52d676422c..5fc371b3b4437 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.v2.avro +import java.util + import org.apache.spark.sql.avro.AvroFileFormat import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.FileFormat @@ -29,15 +31,15 @@ class AvroDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "avro" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + AvroTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - AvroTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + AvroTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index 765e5727d944a..952d9f55d88ad 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class AvroTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index c15f08d78741d..6b61dfe3311ba 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -31,6 +31,7 @@ import org.apache.spark.kafka010.KafkaConfigUpdater import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder} @@ -108,8 +109,21 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParameters)) } - override def getTable(options: 
CaseInsensitiveStringMap): KafkaTable = { - val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false) + override def getTable(schema: StructType, properties: ju.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Kafka source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: ju.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Kafka source does not support user-specified schema/partitioning.") + } + + override def getTable(properties: ju.Map[String, String]): KafkaTable = { + val includeHeaders = Option(properties.get(INCLUDE_HEADERS)).map(_.toBoolean).getOrElse(false) new KafkaTable(includeHeaders) }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java index e9fd87d0e2d40..a57e35507aca2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableProvider.java @@ -17,7 +17,10 @@ package org.apache.spark.sql.connector.catalog; +import java.util.Map; + import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -36,26 +39,44 @@ public interface TableProvider { /**
- * Return a {@link Table} instance to do read/write with user-specified options. + * Return a {@link Table} instance with the user-specified table properties to do read/write. + * Implementations should infer the table schema and partitioning. + * + * @param properties The user-specified table properties that can identify a table, e.g. file + * path, Kafka topic name, etc. The properties map may be a + * {@link CaseInsensitiveStringMap}. + */ + Table getTable(Map<String, String> properties); + + /** + * Return a {@link Table} instance with the user-specified table schema and properties to do + * read/write. Implementations should infer the table partitioning. + * + * @param schema The user-specified table schema. + * @param properties The user-specified table properties that can identify a table, e.g. file + * path, Kafka topic name, etc. The properties map may be a + * {@link CaseInsensitiveStringMap}. * - * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. + * @throws IllegalArgumentException if the user-specified schema does not match the actual table + * schema. */ - Table getTable(CaseInsensitiveStringMap options); + Table getTable(StructType schema, Map<String, String> properties); /**
- * Return a {@link Table} instance to do read/write with user-specified schema and options. - * <p> - * By default this method throws {@link UnsupportedOperationException}, implementations should - * override this method to handle user-specified schema. - * <p> - * @param options the user-specified options that can identify a table, e.g. file path, Kafka - * topic name, etc. It's an immutable case-insensitive string-to-string map. - * @param schema the user-specified schema. - * @throws UnsupportedOperationException + * Return a {@link Table} instance with the user-specified table schema, partitioning and + * properties to do read/write. The returned table must report the same schema and partitioning + * as the user-specified ones, or Spark will fail the operation. + * + * @param schema The user-specified table schema. + * @param partitioning The user-specified table partitioning. + * @param properties The properties of the table to load. It should be sufficient to define and + * access a table. The properties map may be a {@link CaseInsensitiveStringMap}. + * + * @throws IllegalArgumentException if the user-specified schema/partitioning does not match the + * actual table schema/partitioning. */ - default Table getTable(CaseInsensitiveStringMap options, StructType schema) { - throw new UnsupportedOperationException( - this.getClass().getSimpleName() + " source does not support user-specified schema"); - } + Table getTable( + StructType schema, + Transform[] partitioning, + Map<String, String> properties); }
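To illustrate the new contract, a minimal Scala sketch of a provider implementing all three overloads (every name below is hypothetical and not part of this patch; a real source would also mix in SupportsRead or SupportsWrite):

import java.util
import java.util.Collections

import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType

// Hypothetical example, not part of this patch.
class ExampleProvider extends TableProvider {
  private def makeTable(tableSchema: StructType): Table = new Table {
    override def name(): String = "example"
    override def schema(): StructType = tableSchema
    override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
  }

  // No user-specified metadata: infer the schema (and partitioning) from the properties.
  override def getTable(properties: util.Map[String, String]): Table =
    makeTable(new StructType().add("value", "string"))

  // User-specified schema: honor it and infer only the partitioning.
  override def getTable(schema: StructType, properties: util.Map[String, String]): Table =
    makeTable(schema)

  // Fully specified metadata, e.g. when the table is restored from the session catalog.
  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table = {
    require(partitioning.isEmpty, "this example source is unpartitioned")
    makeTable(schema)
  }
}

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index 616c3cf696396..8cf0558b2c25a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -30,7 +30,9 @@ import org.apache.spark.sql.types.StructType /** * An implementation of catalog v2 `Table` to expose v1 table metadata.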
*/ -private[sql] case class V1Table(v1Table: CatalogTable) extends Table { +private[sql] case class V1Table(catalogTable: CatalogTable) extends Table { + assert(catalogTable.provider.isDefined) + implicit class IdentifierHelper(identifier: TableIdentifier) { def quoted: String = { identifier.database match { @@ -38,7 +40,6 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { Seq(db, identifier.table).map(quote).mkString(".") case _ => quote(identifier.table) - } } @@ -51,38 +52,36 @@ private[sql] case class V1Table(v1Table: CatalogTable) extends Table { } } - def catalogTable: CatalogTable = v1Table - lazy val options: Map[String, String] = { - v1Table.storage.locationUri match { + catalogTable.storage.locationUri match { case Some(uri) => - v1Table.storage.properties + ("path" -> uri.toString) + catalogTable.storage.properties + ("path" -> uri.toString) case _ => - v1Table.storage.properties + catalogTable.storage.properties } } - override lazy val properties: util.Map[String, String] = v1Table.properties.asJava + override lazy val properties: util.Map[String, String] = catalogTable.properties.asJava - override lazy val schema: StructType = v1Table.schema + override lazy val schema: StructType = catalogTable.schema override lazy val partitioning: Array[Transform] = { val partitions = new mutable.ArrayBuffer[Transform]() - v1Table.partitionColumnNames.foreach { col => + catalogTable.partitionColumnNames.foreach { col => partitions += LogicalExpressions.identity(col) } - v1Table.bucketSpec.foreach { spec => + catalogTable.bucketSpec.foreach { spec => partitions += LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*) } partitions.toArray } - override def name: String = v1Table.identifier.quoted + override def name: String = catalogTable.identifier.quoted override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]() - override def toString: String = s"UnresolvedTable($name)" + override def toString: String = s"V1Table($name)" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index 414f9d5834868..59299b3157982 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -72,6 +72,8 @@ class InMemoryTable( this } + def clear(): Unit = dataMap.synchronized(dataMap.clear()) + override def capabilities: util.Set[TableCapability] = Set( TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b9cc25817d2f3..1a9c08b30b769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -216,7 +216,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) + case Some(schema) => provider.getTable(schema, dsOptions) case _ => provider.getTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 3f4f29c3e135a..7f445f2d280ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.sources.DataSourceRegister @@ -35,7 +36,20 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap */ class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable + override def getTable(properties: util.Map[String, String]): Table = NoopTable + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read noop source with user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read noop source with user-specified schema/partitioning.") + } } private[noop] object NoopTable extends Table with SupportsWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index e0091293d1669..71d1d5e527f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,20 +16,24 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.util + import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.catalog.TableProvider +import org.apache.spark.sql.connector.catalog.{Table, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** * A base interface for data source v2 implementations of the built-in file-based data sources. */ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { + /** * Returns a V1 [[FileFormat]] class of the same file data source. 
* This is a solution for the following cases: @@ -41,7 +45,7 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { lazy val sparkSession = SparkSession.active - protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = { + protected def getPaths(map: java.util.Map[String, String]): Seq[String] = { val objectMapper = new ObjectMapper() val paths = Option(map.get("paths")).map { pathStr => objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq @@ -59,4 +63,12 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toString } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "file source v2 does not support user-specified partitioning yet.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 5329e09916bd6..30f3416892af4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -34,15 +34,20 @@ import org.apache.spark.sql.util.SchemaUtils abstract class FileTable( sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType]) extends Table with SupportsRead with SupportsWrite { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + private def caseSensitiveOptions = options match { + case m: CaseInsensitiveStringMap => m.asCaseSensitiveMap() + case other => other + } + lazy val fileIndex: PartitioningAwareFileIndex = { - val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + val caseSensitiveMap = caseSensitiveOptions.asScala.toMap // Hadoop Configurations are case sensitive. 
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) if (FileStreamSink.hasMetadata(paths, hadoopConf, sparkSession.sessionState.conf)) { @@ -104,7 +109,7 @@ abstract class FileTable( override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms - override def properties: util.Map[String, String] = options.asCaseSensitiveMap + override def properties: util.Map[String, String] = caseSensitiveOptions override def capabilities: java.util.Set[TableCapability] = FileTable.CAPABILITIES diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index dffb9cb67b5c2..e8eb868b59e74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -23,6 +23,7 @@ import java.util import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} @@ -61,6 +62,47 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) } } + override def tableExists(ident: Identifier): Boolean = { + if (ident.namespace().length <= 1) { + catalog.tableExists(ident.asTableIdentifier) + } else { + false + } + } + + private def tryResolveTableProvider(v1Table: V1Table): Table = { + val providerName = v1Table.catalogTable.provider.get + DataSource.lookupDataSourceV2(providerName, conf).map { + // TODO: support file source v2 in CREATE TABLE USING. 
+ case _: FileDataSourceV2 => v1Table + + case provider => + val table = provider.getTable(v1Table.schema, v1Table.partitioning, v1Table.properties) + validateTableSchemaAndPartitioning( + providerName, table, v1Table.schema, v1Table.partitioning) + table + }.getOrElse(v1Table) + } + + private def validateTableSchemaAndPartitioning( + providerName: String, + table: Table, + metaStoreSchema: StructType, + metaStorePartitioning: Array[Transform]): Unit = { + if (table.schema() != metaStoreSchema) { + throw new AnalysisException(s"Table provider '$providerName' reports a different data " + + "schema from the one in Spark meta-store:\n" + + s"Schema in Spark meta-store: $metaStoreSchema\n" + + s"Actual data schema: ${table.schema}") + } + if (!table.partitioning().sameElements(metaStorePartitioning)) { + throw new AnalysisException(s"Table provider '$providerName' reports a different data " + + "partitioning from the one in Spark meta-store:\n" + + s"Partitioning in Spark meta-store: ${metaStorePartitioning.mkString(", ")}\n" + + s"Actual data partitioning: ${table.partitioning().mkString(", ")}") + } + } + override def loadTable(ident: Identifier): Table = { val catalogTable = try { catalog.getTableMetadata(ident.asTableIdentifier) @@ -69,7 +111,11 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) throw new NoSuchTableException(ident) } - V1Table(catalogTable) + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw new NoSuchTableException(ident) + } + + tryResolveTableProvider(V1Table(catalogTable)) } override def invalidateTable(ident: Identifier): Unit = { @@ -82,8 +128,25 @@ partitions: Array[Transform], properties: util.Map[String, String]): Table = { - val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(partitions) - val provider = properties.getOrDefault("provider", conf.defaultDataSourceName) + val providerName = properties.getOrDefault("provider", conf.defaultDataSourceName) + // It's guaranteed that we only call `V2SessionCatalog.createTable` if the table provider is v2. + val provider = DataSource.lookupDataSourceV2(providerName, conf).get + val (actualSchema, actualPartitioning) = if (schema.isEmpty) { + // A sanity check. The parser should guarantee it. + assert(partitions.isEmpty) + // If `CREATE TABLE ... USING` does not specify table metadata, get the table metadata from + // data source first. + val table = provider.getTable(new CaseInsensitiveStringMap(properties)) + table.schema() -> table.partitioning() + } else { + // The schema/partitioning is specified in `CREATE TABLE ... USING`, validate it.
+ val table = provider.getTable(schema, partitions, properties) + validateTableSchemaAndPartitioning( + providerName, table, schema, partitions) + schema -> partitions + } + + val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(actualPartitioning) val tableProperties = properties.asScala val location = Option(properties.get(LOCATION_TABLE_PROP)) val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap) @@ -94,8 +157,8 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) identifier = ident.asTableIdentifier, tableType = tableType, storage = storage, - schema = schema, - provider = Some(provider), + schema = actualSchema, + provider = Some(providerName), partitionColumnNames = partitionColumns, bucketSpec = maybeBucketSpec, properties = tableProperties.toMap, @@ -136,19 +199,14 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) } override def dropTable(ident: Identifier): Boolean = { - try { - if (loadTable(ident) != null) { - catalog.dropTable( - ident.asTableIdentifier, - ignoreIfNotExists = true, - purge = true /* skip HDFS trash */) - true - } else { - false - } - } catch { - case _: NoSuchTableException => - false + if (tableExists(ident)) { + catalog.dropTable( + ident.asTableIdentifier, + ignoreIfNotExists = true, + purge = true /* skip HDFS trash */) + true + } else { + false } } @@ -156,9 +214,9 @@ class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) if (tableExists(newIdent)) { throw new TableAlreadyExistsException(newIdent) } - - // Load table to make sure the table exists - loadTable(oldIdent) + if (!tableExists(oldIdent)) { + throw new NoSuchTableException(oldIdent) + } catalog.renameTable(oldIdent.asTableIdentifier, newIdent.asTableIdentifier) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala index 1f99d4282f6da..ed2a8e50f04fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.csv +import java.util + import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -29,15 +31,15 @@ class CSVDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "csv" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + CSVTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - CSVTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + CSVTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala index 04beee0e3b0f2..a067b6a370419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class CSVTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala index 7a0949e586cd8..5ab6024a27aee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.json +import java.util + import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -29,16 +31,16 @@ class JsonDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "json" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + JsonTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - JsonTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + JsonTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala index 9bb615528fc5d..518a9242464ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class JsonTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 8665af33b976a..c53b99170db78 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.orc +import java.util + import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat @@ -29,16 +31,15 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + OrcTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - OrcTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + OrcTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index f2e4b88e9f1ae..ebbfa4feaa376 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala index 8cb6186c12ff3..d17252c9480ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.parquet +import java.util + import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -29,16 +31,16 @@ class ParquetDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + ParquetTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: 
CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - ParquetTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + ParquetTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index 2ad64b1aa5244..5b8c3ac9d77bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetTable( name: String, sparkSession: SparkSession, - options: CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala index 049c717effa26..aade932566620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextDataSourceV2.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.datasources.v2.text +import java.util + import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.text.TextFileFormat @@ -29,16 +31,15 @@ class TextDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "text" - override def getTable(options: CaseInsensitiveStringMap): Table = { - val paths = getPaths(options) + override def getTable(properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, options, paths, None, fallbackFileFormat) + TextTable(tableName, sparkSession, properties, paths, None, fallbackFileFormat) } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { - val paths = getPaths(options) + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + val paths = getPaths(properties) val tableName = getTableName(paths) - TextTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat) + TextTable(tableName, sparkSession, properties, paths, Some(schema), fallbackFileFormat) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala index 87bfa84985e5a..c21c1441afff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class TextTable( name: String, sparkSession: SparkSession, - options: 
CaseInsensitiveStringMap, + options: java.util.Map[String, String], paths: Seq[String], userSpecifiedSchema: Option[StructType], fallbackFileFormat: Class[_ <: FileFormat]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 20eb7ae5a6d96..61b6fba65026a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.write.{SupportsTruncate, WriteBuilder} import org.apache.spark.sql.connector.write.streaming.StreamingWrite import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite @@ -39,10 +40,23 @@ class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { ConsoleTable } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read console sink with user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Cannot read console sink with user-specified schema/partitioning.") + } + def createRelation( sqlContext: SQLContext, mode: SaveMode, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 911a526428cf4..f1615b1d518ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.internal.SQLConf @@ -95,9 +96,22 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. 
object MemoryStreamTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "memory stream does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "memory stream does not support user-specified schema/partitioning.") + } } class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3f7b0377f1eab..df95bcde482e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream @@ -48,14 +49,14 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ - override def getTable(options: CaseInsensitiveStringMap): Table = { - val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) + override def getTable(properties: util.Map[String, String]): Table = { + val rowsPerSecond = Option(properties.get(ROWS_PER_SECOND)).map(_.toLong).getOrElse(1L) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive") } - val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME)) + val rampUpTimeSeconds = Option(properties.get(RAMP_UP_TIME)) .map(JavaUtils.timeStringAsSec) .getOrElse(0L) if (rampUpTimeSeconds < 0) { @@ -63,8 +64,9 @@ class RateStreamProvider extends TableProvider with DataSourceRegister { s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative") } - val numPartitions = options.getInt( - NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism) + val numPartitions = Option(properties.get(NUM_PARTITIONS)).map(_.toInt).getOrElse { + SparkSession.active.sparkContext.defaultParallelism + } if (numPartitions <= 0) { throw new IllegalArgumentException( s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") @@ -72,6 +74,19 @@ class RateStreamProvider extends TableProvider with DataSourceRegister { new RateStreamTable(rowsPerSecond, rampUpTimeSeconds, numPartitions) } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Rate stream source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "Rate stream source does not support user-specified schema/partitioning.") + } + override def shortName(): String = "rate" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index fae3cb765c0c9..465e56c806017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousStream @@ -36,7 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: CaseInsensitiveStringMap): Unit = { + private def checkParameters(params: util.Map[String, String]): Unit = { logWarning("The socket source should not be used for production applications! 
" + "It does not support recovery.") if (!params.containsKey("host")) { @@ -46,7 +47,7 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - params.getBoolean("includeTimestamp", false) + Option(params.get("includeTimestamp")).foreach(_.toBoolean) } match { case Success(_) => case Failure(_) => @@ -54,13 +55,28 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } } - override def getTable(options: CaseInsensitiveStringMap): Table = { - checkParameters(options) + override def getTable(properties: util.Map[String, String]): Table = { + checkParameters(properties) new TextSocketTable( - options.get("host"), - options.getInt("port", -1), - options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), - options.getBoolean("includeTimestamp", false)) + properties.get("host"), + Option(properties.get("port")).map(_.toInt).getOrElse(-1), + Option(properties.get("numPartitions")).map(_.toInt).getOrElse { + SparkSession.active.sparkContext.defaultParallelism + }, + Option(properties.get("includeTimestamp")).map(_.toBoolean).getOrElse(false)) + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "socket source does not support user-specified schema/partitioning.") + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException( + "socket source does not support user-specified schema/partitioning.") } /** String that represents the format that this data source provider uses. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4a6516d325ddd..796ab1997acb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -188,7 +188,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) + case Some(schema) => provider.getTable(schema, dsOptions) case _ => provider.getTable(dsOptions) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java index 9386ab51d64f0..14429a06ca0c9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2.java @@ -22,18 +22,18 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; import 
org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaAdvancedDataSourceV2 implements TableProvider { +public class JavaAdvancedDataSourceV2 extends TestingV2Source { @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 76da45e182b3c..e5c74b3679450 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -18,10 +18,11 @@ package test.org.apache.spark.sql.connector; import java.io.IOException; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -33,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; -public class JavaColumnarDataSourceV2 implements TableProvider { +public class JavaColumnarDataSourceV2 extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -52,7 +53,7 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java index fbbc457b2945d..556507e255c68 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaPartitionAwareDataSource.java @@ -19,20 +19,21 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.*; import org.apache.spark.sql.connector.read.partitioning.ClusteredDistribution; import org.apache.spark.sql.connector.read.partitioning.Distribution; import org.apache.spark.sql.connector.read.partitioning.Partitioning; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaPartitionAwareDataSource implements TableProvider { +public class JavaPartitionAwareDataSource extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportPartitioning { @@ -56,7 +57,7 @@ public Partitioning outputPartitioning() { } @Override - public Table 
getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public Transform[] partitioning() {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java index 49438fe668d56..fd8819ae52e96 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaReportStatisticsDataSource.java @@ -17,17 +17,18 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; import java.util.OptionalLong; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.Statistics; import org.apache.spark.sql.connector.read.SupportsReportStatistics; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaReportStatisticsDataSource implements TableProvider { +public class JavaReportStatisticsDataSource extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { @Override public Statistics estimateStatistics() { @@ -54,7 +55,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java index 2181887ae54e2..9923b8c34353a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSchemaRequiredDataSource.java @@ -17,8 +17,11 @@ package test.org.apache.spark.sql.connector; +import java.util.Map; + import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.types.StructType; @@ -46,7 +49,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options, StructType schema) { + public Table getTable(StructType schema, Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override @@ -62,7 +65,16 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable( + StructType schema, + Transform[] partitioning, + Map<String, String> properties) { + assert partitioning.length == 0; + return getTable(schema, properties); + } + + @Override + public Table getTable(Map<String, String> properties) { throw new IllegalArgumentException("requires a user-supplied schema"); } }
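As a usage note (a hypothetical snippet, not part of the patch): with the DataFrameReader change earlier in this diff, a user-supplied `.schema(...)` is what routes the call to `getTable(schema, properties)`, so reading this schema-required test source looks like:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType

object SchemaRequiredReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    // DataFrameReader calls provider.getTable(schema, dsOptions) because a
    // user-specified schema is present; without .schema(...) it would call
    // provider.getTable(dsOptions), which this source rejects.
    val df = spark.read
      .format("test.org.apache.spark.sql.connector.JavaSchemaRequiredDataSource")
      .schema(new StructType().add("i", "int").add("j", "int"))
      .load()
    df.printSchema()
    spark.stop()
  }
}

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java 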
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java index 97b00477e1764..71cf97b56fe54 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleBatchTable.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.Set; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; @@ -34,7 +35,7 @@ abstract class JavaSimpleBatchTable implements Table, SupportsRead { @Override public StructType schema() { - return new StructType().add("i", "int").add("j", "int"); + return TestingV2Source.schema(); } @Override
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java index 8b6d71b986ff7..32c23381100ce 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleDataSourceV2.java @@ -17,13 +17,15 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class JavaSimpleDataSourceV2 implements TableProvider { +import java.util.Map; + +public class JavaSimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends JavaSimpleScanBuilder { @@ -37,7 +39,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(CaseInsensitiveStringMap options) { + public Table getTable(Map<String, String> properties) { return new JavaSimpleBatchTable() { @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java index 7cbba00420928..bdd9dd3ea0ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaSimpleScanBuilder.java @@ -17,6 +17,7 @@ package test.org.apache.spark.sql.connector; +import org.apache.spark.sql.connector.TestingV2Source; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.PartitionReaderFactory; import org.apache.spark.sql.connector.read.Scan; @@ -37,7 +38,7 @@ public Batch toBatch() { @Override public StructType readSchema() { - return new StructType().add("i", "int").add("j", "int"); + return TestingV2Source.schema(); } @Override
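`TestingV2Source`, which the Java test sources above now extend, is defined elsewhere in this change set. As an assumption inferred from its call sites here (not the actual helper), its shape is roughly:

import java.util

import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType

// Sketch only: the real helper lives outside this hunk.
object TestingV2Source {
  // The fixed schema that JavaSimpleBatchTable.schema() and
  // JavaSimpleScanBuilder.readSchema() now return.
  def schema(): StructType = new StructType().add("i", "int").add("j", "int")
}

abstract class TestingV2Source extends TableProvider {
  // Test sources ignore user-specified metadata, so the extra overloads
  // simply delegate to the properties-only variant.
  override def getTable(schema: StructType, properties: util.Map[String, String]): Table =
    getTable(properties)

  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table =
    getTable(properties)
}

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index e27575cecde25..cc3d5b48cb47d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -92,12 +92,6 @@ 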
class DataSourceV2DataFrameSessionCatalogSuite } } -class InMemoryTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("D'oh!") - } -} - class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { override def newTable( name: String, @@ -131,7 +125,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable } } -private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] +private[connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -140,7 +134,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio spark.sessionState.catalogManager.catalog(name) } - protected val v2Format: String = classOf[InMemoryTableProvider].getName + protected val v2Format: String = classOf[FakeV2Provider].getName protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 3b42c2374f006..0f83fee94e886 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.connector +import java.util + import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.SimpleScanSource @@ -1175,7 +1178,18 @@ class DataSourceV2SQLSuite /** Used as a V2 DataSource for V2SessionCatalog DDL */ class FakeV2Provider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("Unnecessary for DDL tests") + override def getTable(properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException() + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException() + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala new file mode 100644 index 0000000000000..0a8c74aae4e80 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLUsingSuite.scala @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.{LogicalExpressions, Transform} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class DataSourceV2SQLUsingSuite extends QueryTest with SharedSparkSession { + + override def beforeEach(): Unit = { + super.beforeEach() + ReadWriteV2Source.table.clear() + PartitionedV2Source.table.clear() + } + + test("basic") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + + val e = intercept[AnalysisException](sql("INSERT INTO t SELECT 1")) + assert(e.message.contains("not enough data columns")) + + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("CREATE TABLE with a mismatched schema") { + withTable("t") { + val e = intercept[AnalysisException]( + sql(s"CREATE TABLE t(i INT) USING ${classOf[ReadWriteV2Source].getName}") + ) + assert(e.getMessage.contains("reports a different data schema")) + + sql(s"CREATE TABLE t(i INT, j INT) USING ${classOf[ReadWriteV2Source].getName}") + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("CREATE TABLE with a mismatched partitioning") { + withTable("t") { + val e = intercept[AnalysisException]( + sql( + s""" + |CREATE TABLE t(i INT, j INT) USING ${classOf[PartitionedV2Source].getName} + |PARTITIONED BY (i) + """.stripMargin) + ) + assert(e.getMessage.contains("reports a different data partitioning")) + + sql( + s""" + |CREATE TABLE t(i INT, j INT) USING ${classOf[PartitionedV2Source].getName} + |PARTITIONED BY (j) + """.stripMargin) + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + } + } + + test("read-only table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadOnlyV2Source].getName}") + checkAnswer(spark.table("t"), Seq(Row(1), Row(-1))) + + val e1 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1, -1")) + assert(e1.message.contains("too many data columns")) + + val e2 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1")) + assert(e2.message.contains("does not support append in batch mode")) + } + } + + test("write-only table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[WriteOnlyV2Source].getName}") + + val e1 = intercept[AnalysisException](sql("INSERT INTO t SELECT 1, 1")) + assert(e1.message.contains("too many data columns")) + + sql("INSERT INTO t SELECT 1") + + val e2 = intercept[AnalysisException](sql("SELECT * FROM t").collect()) + 
assert(e2.message.contains("Table write-only does not support batch scan")) + } + } + + test("CREATE TABLE AS SELECT") { + withTable("t") { + sql( + s""" + |CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName} + |AS SELECT 1 AS i, -1 AS j + """.stripMargin) + checkAnswer(spark.table("t"), Row(1, -1)) + + sql("INSERT INTO t SELECT 2, -2") + checkAnswer(spark.table("t"), Seq(Row(1, -1), Row(2, -2))) + } + } + + test("INSERT OVERWRITE with non-partitioned table") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + + sql("INSERT OVERWRITE t SELECT 2, -2") + checkAnswer(spark.table("t"), Row(2, -2)) + } + } + + test("INSERT OVERWRITE with partitioned table (static mode)") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "static") { + sql("INSERT OVERWRITE t PARTITION(j=2) SELECT 4") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(4, 2))) + + sql("INSERT OVERWRITE t SELECT 0, 1") + checkAnswer(spark.table("t"), Row(0, 1)) + } + } + } + + test("INSERT OVERWRITE with partitioned table (dynamic mode)") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { + sql("INSERT OVERWRITE t PARTITION(j=2) SELECT 4") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(4, 2))) + + sql("INSERT OVERWRITE t SELECT 0, 1") + checkAnswer(spark.table("t"), Seq(Row(0, 1), Row(4, 2))) + } + } + } + + test("DELETE FROM") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[PartitionedV2Source].getName}") + sql("INSERT INTO t SELECT 1, 1") + sql("INSERT INTO t SELECT 2, 1") + sql("INSERT INTO t SELECT 3, 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1), Row(3, 2))) + + sql("DELETE FROM t WHERE j = 2") + checkAnswer(spark.table("t"), Seq(Row(1, 1), Row(2, 1))) + } + } + + test("ALTER TABLE") { + withTable("t") { + sql(s"CREATE TABLE t USING ${classOf[ReadWriteV2Source].getName}") + sql("INSERT INTO t SELECT 1, -1") + checkAnswer(spark.table("t"), Row(1, -1)) + + val e = intercept[AnalysisException](sql("ALTER TABLE t DROP COLUMN i")) + assert(e.message.contains("reports a different data schema")) + } + } + + test("ALTER TABLE with a table that supports schema changes") { + withTable("t") { + sql(s"CREATE TABLE t(i INT) USING ${classOf[DynamicSchemaV2Source].getName}") + checkAnswer(spark.table("t"), Nil) + + sql("ALTER TABLE t ADD COLUMN j INT") + checkAnswer(spark.table("t"), Nil) + } + } +} + +object ReadWriteV2Source { + val table = { + val schema = new StructType().add("i", "int").add("j", "int") + val partitioning = Array.empty[Transform] + val properties = util.Collections.emptyMap[String, String] + new InMemoryTable("read-write", schema, partitioning, properties) {} + } +} + +class ReadWriteV2Source extends TableProvider { + // `TableProvider` is instantiated by reflection every time it is accessed. To preserve the + // in-memory table's data across instantiations, we keep the table instance in an object.
+ override def getTable(properties: util.Map[String, String]): Table = { + ReadWriteV2Source.table + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +object PartitionedV2Source { + val table = { + val schema = new StructType().add("i", "int").add("j", "int") + val partitioning = Array[Transform](LogicalExpressions.identity("j")) + val properties = util.Collections.emptyMap[String, String] + new InMemoryTable("partitioned", schema, partitioning, properties) {} + } +} + +class PartitionedV2Source extends TableProvider { + // `TableProvider` is instantiated by reflection every time it is accessed. To preserve the + // in-memory table's data across instantiations, we keep the table instance in an object. + override def getTable(properties: util.Map[String, String]): Table = { + PartitionedV2Source.table + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +class ReadOnlyV2Source extends TableProvider { + override def getTable(properties: util.Map[String, String]): Table = { + val schema = new StructType().add("i", "int") + val partitions = Array.empty[Transform] + val tableProps = util.Collections.emptyMap[String, String] + val table = new InMemoryTable("read-only", schema, partitions, tableProps) { + override def capabilities: util.Set[TableCapability] = { + Set(TableCapability.BATCH_READ).asJava + } + } + val rows = new BufferedRows() + rows.withRow(InternalRow(1)).withRow(InternalRow(-1)) + table.withData(Array(rows)) + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +class WriteOnlyV2Source extends TableProvider { + override def getTable(properties: util.Map[String, String]): Table = { + val schema = new StructType().add("i", "int") + val partitions = Array.empty[Transform] + val tableProps = util.Collections.emptyMap[String, String] + new InMemoryTable("write-only", schema, partitions, tableProps) { + override def capabilities: util.Set[TableCapability] = { + Set(TableCapability.BATCH_WRITE).asJava + } + } + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +class DynamicSchemaV2Source extends TableProvider { + override def getTable(properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", new StructType(), Array.empty, properties) {} + } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", schema, Array.empty, properties) {} + } + + override def getTable( + schema: StructType, +
partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + new InMemoryTable("dynamic-schema", schema, partitions, properties) {} + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 138bbc3f04f64..df0265a763ad6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} @@ -415,9 +416,29 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } +abstract class TestingV2Source extends TableProvider { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + assert(schema == TestingV2Source.schema) + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + assert(schema == TestingV2Source.schema) + assert(partitioning.isEmpty) + getTable(new CaseInsensitiveStringMap(properties)) + } +} + +object TestingV2Source { + val schema = new StructType().add("i", "int").add("j", "int") +} + abstract class SimpleBatchTable extends Table with SupportsRead { - override def schema(): StructType = new StructType().add("i", "int").add("j", "int") + override def schema(): StructType = TestingV2Source.schema override def name(): String = this.getClass.toString @@ -431,12 +452,12 @@ abstract class SimpleScanBuilder extends ScanBuilder override def toBatch: Batch = this - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def readSchema(): StructType = TestingV2Source.schema override def createReaderFactory(): PartitionReaderFactory = SimpleReaderFactory } -class SimpleSinglePartitionSource extends TableProvider { +class SimpleSinglePartitionSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -444,7 +465,7 @@ class SimpleSinglePartitionSource extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -453,7 +474,7 @@ class SimpleSinglePartitionSource extends TableProvider { // This class is used by pyspark tests. If this class is modified/moved, make sure pyspark // tests still pass. 
-class SimpleDataSourceV2 extends TableProvider { +class SimpleDataSourceV2 extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { override def planInputPartitions(): Array[InputPartition] = { @@ -461,16 +482,16 @@ class SimpleDataSourceV2 extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } } -class AdvancedDataSourceV2 extends TableProvider { +class AdvancedDataSourceV2 extends TestingV2Source { - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } @@ -566,11 +587,13 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + override def getTable( + schema: StructType, + properties: util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema @@ -580,9 +603,16 @@ class SchemaRequiredDataSource extends TableProvider { } } } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(schema, properties) + } } -class ColumnarDataSourceV2 extends TableProvider { +class ColumnarDataSourceV2 extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder { @@ -595,7 +625,7 @@ class ColumnarDataSourceV2 extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -647,7 +677,7 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } } -class PartitionAwareDataSource extends TableProvider { +class PartitionAwareDataSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder with SupportsReportPartitioning{ @@ -666,7 +696,7 @@ class PartitionAwareDataSource extends TableProvider { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def getTable(properties: util.Map[String, String]): Table = new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } @@ -706,8 +736,8 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def getTable(options: CaseInsensitiveStringMap): Table = { - new MyTable(options) { + override def getTable(properties: util.Map[String, String]): Table = { + new MyTable(properties) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") } 
@@ -715,7 +745,7 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { } } -class ReportStatisticsDataSource extends TableProvider { +class ReportStatisticsDataSource extends TestingV2Source { class MyScanBuilder extends SimpleScanBuilder with SupportsReportStatistics { @@ -732,7 +762,7 @@ class ReportStatisticsDataSource extends TableProvider { } } - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new SimpleBatchTable { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index 2b3340527a4e2..765ad0858805f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.connector +import java.util + import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -40,9 +42,13 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new DummyReadOnlyFileTable } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException + } } class DummyReadOnlyFileTable extends Table with SupportsRead { @@ -64,9 +70,13 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new DummyWriteOnlyFileTable } + + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + throw new UnsupportedOperationException + } } class DummyWriteOnlyFileTable extends Table with SupportsWrite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index 22d3750022c57..3ee6c0f780f03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. 
*/ -class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TestingV2Source with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") @@ -131,10 +131,10 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } } - class MyTable(options: CaseInsensitiveStringMap) + class MyTable(properties: util.Map[String, String]) extends SimpleBatchTable with SupportsWrite { - private val path = options.get("path") + private val path = properties.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration override def schema(): StructType = tableSchema @@ -151,8 +151,8 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } - override def getTable(options: CaseInsensitiveStringMap): Table = { - new MyTable(options) + override def getTable(properties: util.Map[String, String]): Table = { + new MyTable(properties) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index ce6d56cf84df1..99bff90700749 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -39,7 +39,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession { private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = { StreamingRelationV2( - TestTableProvider, + new FakeV2Provider, "fake", table, CaseInsensitiveStringMap.empty(), @@ -203,12 +203,6 @@ private case object TestRelation extends LeafNode with NamedRelation { override def output: Seq[AttributeReference] = TableCapabilityCheckSuite.schema.toAttributes } -private object TestTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException - } -} - private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index de843ba4375d0..de6b985cc7cdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -176,18 +176,29 @@ class InMemoryV1Provider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - InMemoryV1Provider.tables.getOrElse(options.get("name"), { + override def getTable(properties: util.Map[String, String]): Table = { + + InMemoryV1Provider.tables.getOrElse(properties.get("name"), { new InMemoryTableWithV1Fallback( "InMemoryTableWithV1Fallback", new StructType(), Array.empty, - options.asCaseSensitiveMap() - ) + properties) }) } + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, 
String]): Table = { + getTable(new CaseInsensitiveStringMap(properties)) + } + override def shortName(): String = "in-memory" override def createRelation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 70b1db8e5f0d2..0a0b79afbff67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2671,7 +2671,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val e = intercept[AnalysisException] { sql("ALTER TABLE v1 ADD COLUMNS (c3 INT)") } - assert(e.message.contains("ALTER ADD COLUMNS does not support views")) + assert(e.message.contains("'v1' is a view not a table")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 674efa9b8ba42..e97e3cf18d843 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, import org.apache.spark.sql.catalyst.expressions.{EqualTo, IntegerLiteral, StringLiteral} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateTableAsSelect, CreateV2Table, DescribeTable, DropTable, LogicalPlan, SubqueryAlias, UpdateTable} -import org.apache.spark.sql.connector.InMemoryTableProvider +import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCatalog, TableChange, V1Table} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ - private val v2Format = classOf[InMemoryTableProvider].getName + private val v2Format = classOf[FakeV2Provider].getName private val table: Table = { val t = mock(classOf[Table]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index f791ab66e86fa..e82c4518cce93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -200,10 +200,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSparkSession { StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) + provider.getTable(userSpecifiedSchema, params.asJava) } - assert(exception.getMessage.contains( - "TextSocketSourceProvider source does not support user-specified schema")) + assert(exception.getMessage.contains("socket source does not support user-specified schema")) } test("input row metrics") { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index e9d148c38e6cb..6eb7933c638e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.connector.catalog.{SessionConfigSupport, SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReaderFactory, ContinuousStream, MicroBatchStream, Offset, PartitionOffset} import org.apache.spark.sql.connector.write.{WriteBuilder, WriterCommitMessage} @@ -90,16 +91,29 @@ trait FakeStreamingWriteTable extends Table with SupportsWrite { } } +trait FakeTableProvider extends TableProvider { + override def getTable(schema: StructType, properties: util.Map[String, String]): Table = { + getTable(properties) + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + getTable(properties) + } +} + class FakeReadMicroBatchOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-microbatch-only" override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def getTable(properties: util.Map[String, String]): Table = { + LastReadOptions.options = properties new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -115,14 +129,14 @@ class FakeReadMicroBatchOnly class FakeReadContinuousOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-read-continuous-only" override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastReadOptions.options = options + override def getTable(properties: util.Map[String, String]): Table = { + LastReadOptions.options = properties new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -136,10 +150,10 @@ class FakeReadContinuousOnly } } -class FakeReadBothModes extends DataSourceRegister with TableProvider { +class FakeReadBothModes extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table with SupportsRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -153,10 +167,10 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { } } -class FakeReadNeitherMode extends DataSourceRegister with TableProvider { +class FakeReadNeitherMode extends DataSourceRegister with 
FakeTableProvider { override def shortName(): String = "fake-read-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -167,14 +181,14 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { class FakeWriteOnly extends DataSourceRegister - with TableProvider + with FakeTableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" override def keyPrefix: String = shortName() - override def getTable(options: CaseInsensitiveStringMap): Table = { - LastWriteOptions.options = options + override def getTable(properties: util.Map[String, String]): Table = { + LastWriteOptions.options = properties new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -182,9 +196,9 @@ class FakeWriteOnly } } -class FakeNoWrite extends DataSourceRegister with TableProvider { +class FakeNoWrite extends DataSourceRegister with FakeTableProvider { override def shortName(): String = "fake-write-neither-mode" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -200,7 +214,7 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with TableProvider with StreamSinkProvider { + with FakeTableProvider with StreamSinkProvider { override def createSink( sqlContext: SQLContext, @@ -212,7 +226,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" - override def getTable(options: CaseInsensitiveStringMap): Table = { + override def getTable(properties: util.Map[String, String]): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -221,7 +235,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister } object LastReadOptions { - var options: CaseInsensitiveStringMap = _ + var options: util.Map[String, String] = _ def clear(): Unit = { options = null @@ -229,7 +243,7 @@ object LastReadOptions { } object LastWriteOptions { - var options: CaseInsensitiveStringMap = _ + var options: util.Map[String, String] = _ def clear(): Unit = { options = null @@ -347,7 +361,7 @@ class StreamingDataSourceV2Suite extends StreamTest { eventually(timeout(streamingTimeout)) { // Write options should not be set. assert(!LastWriteOptions.options.containsKey(readOptionName)) - assert(LastReadOptions.options.getBoolean(readOptionName, false)) + assert(LastReadOptions.options.get(readOptionName) == "true") } } } @@ -358,7 +372,7 @@ class StreamingDataSourceV2Suite extends StreamTest { eventually(timeout(streamingTimeout)) { // Read options should not be set. 
assert(!LastReadOptions.options.containsKey(writeOptionName)) - assert(LastWriteOptions.options.getBoolean(writeOptionName, false)) + assert(LastWriteOptions.options.get(writeOptionName) == "true") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 4253fe2e1edcb..406e41db791bc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -769,6 +769,11 @@ class HiveDDLSuite } } + private def assertErrorForProcessViewAsTable(sqlText: String): Unit = { + val message = intercept[AnalysisException](sql(sqlText)).getMessage + assert(message.contains("Invalid command")) + } + private def assertErrorForAlterTableOnView(sqlText: String): Unit = { val message = intercept[AnalysisException](sql(sqlText)).getMessage assert(message.contains("Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) @@ -831,13 +836,13 @@ class HiveDDLSuite assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") assertErrorForAlterViewOnTable(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") - assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") + assertErrorForProcessViewAsTable(s"ALTER TABLE $oldViewName SET LOCATION '/path/to/home'") assertErrorForAlterTableOnView(s"ALTER TABLE $oldViewName SET SERDE 'whatever'")
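Taken together, the suites in this patch migrate every test `TableProvider` from the single `CaseInsensitiveStringMap` entry point to the three `java.util.Map`-based `getTable` overloads, usually by funnelling the schema and partitioning variants into the properties-only one (see `FakeTableProvider` and `TestingV2Source` above). For reference, here is a minimal, self-contained sketch of that pattern; the class name `ExampleProvider`, its fixed two-column schema, and the `require` messages are illustrative only, not part of this patch:

```scala
import java.util

import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.types.StructType

class ExampleProvider extends TableProvider {

  // The schema this source would normally infer from its properties.
  private val fixedSchema = new StructType().add("i", "int").add("j", "int")

  // Properties-only variant: the provider infers schema and partitioning itself.
  override def getTable(properties: util.Map[String, String]): Table = new Table {
    override def name(): String = "example"
    override def schema(): StructType = fixedSchema
    override def capabilities(): util.Set[TableCapability] =
      util.Collections.emptySet[TableCapability]
  }

  // User-specified schema: validate it against the inferred one, then delegate.
  override def getTable(schema: StructType, properties: util.Map[String, String]): Table = {
    require(schema == fixedSchema, "user-specified schema must match the actual schema")
    getTable(properties)
  }

  // User-specified schema and partitioning: validate both, then delegate.
  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table = {
    require(partitioning.isEmpty, "this source is unpartitioned")
    getTable(schema, properties)
  }
}
```

As the DataSourceV2SQLUsingSuite tests in this patch suggest, `CREATE TABLE t USING ...` with no column list reaches the properties-only overload, an explicit column list reaches the schema variant, and adding `PARTITIONED BY` reaches the schema-plus-partitioning variant.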